mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-02 19:43:40 +00:00
Compare commits
115 Commits
v0.9.0
...
refactor/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e994abdd9 | ||
|
|
26a18346b2 | ||
|
|
99fcc354e3 | ||
|
|
456987a3d4 | ||
|
|
347c31f93c | ||
|
|
836ae7affe | ||
|
|
dd46322421 | ||
|
|
71f5dc987a | ||
|
|
6992fd2b66 | ||
|
|
92895ebe5a | ||
|
|
c0fb3bf95f | ||
|
|
abe31f216f | ||
|
|
44bc65691e | ||
|
|
7c27558de9 | ||
|
|
51ef19a3fb | ||
|
|
8e7301b79a | ||
|
|
ec98a21933 | ||
|
|
1dd59f5d08 | ||
|
|
ea084e775e | ||
|
|
41be436c04 | ||
|
|
b73b16e102 | ||
|
|
8f9960bcc7 | ||
|
|
3c70617060 | ||
|
|
3a98ae3f70 | ||
|
|
1894ddc786 | ||
|
|
f23be16e98 | ||
|
|
b882dfa8f6 | ||
|
|
d491cbd3d2 | ||
|
|
334ba555fc | ||
|
|
ba632d0b4d | ||
|
|
b5d3e87ea2 | ||
|
|
f22ea6e0a8 | ||
|
|
9f1ab16aa5 | ||
|
|
0dd475d2ff | ||
|
|
dd374cdd9b | ||
|
|
daf3ef9848 | ||
|
|
23ee0fc3b4 | ||
|
|
08638b18ce | ||
|
|
d331f0fb2a | ||
|
|
4b98fceb6e | ||
|
|
ef63416098 | ||
|
|
50a432180d | ||
|
|
2ea7634549 | ||
|
|
10da082412 | ||
|
|
31c8ead1d4 | ||
|
|
00f4594062 | ||
|
|
467e584359 | ||
|
|
f635fc3ae6 | ||
|
|
168ebb1cd4 | ||
|
|
b7bc609a7a | ||
|
|
046c8b27b6 | ||
|
|
4be61d00e4 | ||
|
|
4ac7d94026 | ||
|
|
9af71caf73 | ||
|
|
91e57a4c69 | ||
|
|
45a6a779e5 | ||
|
|
49c7a0dee5 | ||
|
|
956244c742 | ||
|
|
752dc11dd4 | ||
|
|
17be7c3b45 | ||
|
|
11cf70e60d | ||
|
|
dfa27f3412 | ||
|
|
e34b5def60 | ||
|
|
63f94e7669 | ||
|
|
18a385f817 | ||
|
|
8e95d338b5 | ||
|
|
f236785ed5 | ||
|
|
f3e220b196 | ||
|
|
33bf267ce8 | ||
|
|
05c2dde38f | ||
|
|
0ee5670be6 | ||
|
|
9790e2c4f6 | ||
|
|
4f760a8d40 | ||
|
|
8563eafc57 | ||
|
|
72d5b35d3f | ||
|
|
7d71f467d9 | ||
|
|
aea732ab92 | ||
|
|
da6f24a3d4 | ||
|
|
28ed42130c | ||
|
|
96215c9fd5 | ||
|
|
6628fd9181 | ||
|
|
a3b8a1998a | ||
|
|
6a34d365ec | ||
|
|
406a3e4dca | ||
|
|
c1d7ecdeec | ||
|
|
6451158680 | ||
|
|
0bd4b34046 | ||
|
|
f14b06ec3a | ||
|
|
6ed775be8f | ||
|
|
b712279b2a | ||
|
|
1bffe3081d | ||
|
|
cfebe80822 | ||
|
|
17e697af8f | ||
|
|
01b35bb667 | ||
|
|
d8410d2f11 | ||
|
|
e68eed3d40 | ||
|
|
04cc668430 | ||
|
|
5d76e16324 | ||
|
|
b6c547ae98 | ||
|
|
93adcd57d7 | ||
|
|
e813da59cc | ||
|
|
b25ac0bfb6 | ||
|
|
70c27bc662 | ||
|
|
db6a788e0d | ||
|
|
e3bc40f11b | ||
|
|
3e9be07db4 | ||
|
|
684caa3673 | ||
|
|
47aaa695b2 | ||
|
|
cda73a2ec5 | ||
|
|
27a0a447d0 | ||
|
|
fcdfd027cd | ||
|
|
3f9698bb47 | ||
|
|
d15718a87e | ||
|
|
da5aace109 | ||
|
|
81e29aaa3d |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -11,4 +11,6 @@ web/dist
|
||||
one-api
|
||||
.DS_Store
|
||||
tiktoken_cache
|
||||
.eslintcache
|
||||
.eslintcache
|
||||
.cursor
|
||||
*.mdc
|
||||
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
22
common/ip.go
Normal file
22
common/ip.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package common
|
||||
|
||||
import "net"
|
||||
|
||||
func IsPrivateIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
private := []net.IPNet{
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
|
||||
}
|
||||
|
||||
for _, privateNet := range private {
|
||||
if privateNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
327
common/ssrf_protection.go
Normal file
327
common/ssrf_protection.go
Normal file
@@ -0,0 +1,327 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SSRFProtection SSRF防护配置
|
||||
type SSRFProtection struct {
|
||||
AllowPrivateIp bool
|
||||
DomainFilterMode bool // true: 白名单, false: 黑名单
|
||||
DomainList []string // domain format, e.g. example.com, *.example.com
|
||||
IpFilterMode bool // true: 白名单, false: 黑名单
|
||||
IpList []string // CIDR or single IP
|
||||
AllowedPorts []int // 允许的端口范围
|
||||
ApplyIPFilterForDomain bool // 对域名启用IP过滤
|
||||
}
|
||||
|
||||
// DefaultSSRFProtection 默认SSRF防护配置
|
||||
var DefaultSSRFProtection = &SSRFProtection{
|
||||
AllowPrivateIp: false,
|
||||
DomainFilterMode: true,
|
||||
DomainList: []string{},
|
||||
IpFilterMode: true,
|
||||
IpList: []string{},
|
||||
AllowedPorts: []int{},
|
||||
}
|
||||
|
||||
// isPrivateIP 检查IP是否为私有地址
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查私有网段
|
||||
private := []net.IPNet{
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
||||
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
||||
{IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
|
||||
{IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
|
||||
{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
|
||||
}
|
||||
|
||||
for _, privateNet := range private {
|
||||
if privateNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IPv6私有地址
|
||||
if ip.To4() == nil {
|
||||
// IPv6 loopback
|
||||
if ip.Equal(net.IPv6loopback) {
|
||||
return true
|
||||
}
|
||||
// IPv6 link-local
|
||||
if strings.HasPrefix(ip.String(), "fe80:") {
|
||||
return true
|
||||
}
|
||||
// IPv6 unique local
|
||||
if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// parsePortRanges 解析端口范围配置
|
||||
// 支持格式: "80", "443", "8000-9000"
|
||||
func parsePortRanges(portConfigs []string) ([]int, error) {
|
||||
var ports []int
|
||||
|
||||
for _, config := range portConfigs {
|
||||
config = strings.TrimSpace(config)
|
||||
if config == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(config, "-") {
|
||||
// 处理端口范围 "8000-9000"
|
||||
parts := strings.Split(config, "-")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid port range format: %s", config)
|
||||
}
|
||||
|
||||
startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
|
||||
}
|
||||
|
||||
endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
|
||||
}
|
||||
|
||||
if startPort > endPort {
|
||||
return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
|
||||
}
|
||||
|
||||
if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
|
||||
return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
|
||||
}
|
||||
|
||||
// 添加范围内的所有端口
|
||||
for port := startPort; port <= endPort; port++ {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
} else {
|
||||
// 处理单个端口 "80"
|
||||
port, err := strconv.Atoi(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port number: %s", config)
|
||||
}
|
||||
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
|
||||
}
|
||||
|
||||
ports = append(ports, port)
|
||||
}
|
||||
}
|
||||
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
// isAllowedPort 检查端口是否被允许
|
||||
func (p *SSRFProtection) isAllowedPort(port int) bool {
|
||||
if len(p.AllowedPorts) == 0 {
|
||||
return true // 如果没有配置端口限制,则允许所有端口
|
||||
}
|
||||
|
||||
for _, allowedPort := range p.AllowedPorts {
|
||||
if port == allowedPort {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isDomainWhitelisted 检查域名是否在白名单中
|
||||
func isDomainListed(domain string, list []string) bool {
|
||||
if len(list) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
for _, item := range list {
|
||||
item = strings.ToLower(strings.TrimSpace(item))
|
||||
if item == "" {
|
||||
continue
|
||||
}
|
||||
// 精确匹配
|
||||
if domain == item {
|
||||
return true
|
||||
}
|
||||
// 通配符匹配 (*.example.com)
|
||||
if strings.HasPrefix(item, "*.") {
|
||||
suffix := strings.TrimPrefix(item, "*.")
|
||||
if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *SSRFProtection) isDomainAllowed(domain string) bool {
|
||||
listed := isDomainListed(domain, p.DomainList)
|
||||
if p.DomainFilterMode { // 白名单
|
||||
return listed
|
||||
}
|
||||
// 黑名单
|
||||
return !listed
|
||||
}
|
||||
|
||||
// isIPWhitelisted 检查IP是否在白名单中
|
||||
|
||||
func isIPListed(ip net.IP, list []string) bool {
|
||||
if len(list) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, whitelistCIDR := range list {
|
||||
_, network, err := net.ParseCIDR(whitelistCIDR)
|
||||
if err != nil {
|
||||
// 尝试作为单个IP处理
|
||||
if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
|
||||
if ip.Equal(whitelistIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsIPAccessAllowed 检查IP是否允许访问
|
||||
func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
|
||||
// 私有IP限制
|
||||
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||
return false
|
||||
}
|
||||
|
||||
listed := isIPListed(ip, p.IpList)
|
||||
if p.IpFilterMode { // 白名单
|
||||
return listed
|
||||
}
|
||||
// 黑名单
|
||||
return !listed
|
||||
}
|
||||
|
||||
// ValidateURL 验证URL是否安全
|
||||
func (p *SSRFProtection) ValidateURL(urlStr string) error {
|
||||
// 解析URL
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %v", err)
|
||||
}
|
||||
|
||||
// 只允许HTTP/HTTPS协议
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
|
||||
}
|
||||
|
||||
// 解析主机和端口
|
||||
host, portStr, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
// 没有端口,使用默认端口
|
||||
host = u.Hostname()
|
||||
if u.Scheme == "https" {
|
||||
portStr = "443"
|
||||
} else {
|
||||
portStr = "80"
|
||||
}
|
||||
}
|
||||
|
||||
// 验证端口
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
if !p.isAllowedPort(port) {
|
||||
return fmt.Errorf("port %d is not allowed", port)
|
||||
}
|
||||
|
||||
// 如果 host 是 IP,则跳过域名检查
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if !p.IsIPAccessAllowed(ip) {
|
||||
if isPrivateIP(ip) {
|
||||
return fmt.Errorf("private IP address not allowed: %s", ip.String())
|
||||
}
|
||||
if p.IpFilterMode {
|
||||
return fmt.Errorf("ip not in whitelist: %s", ip.String())
|
||||
}
|
||||
return fmt.Errorf("ip in blacklist: %s", ip.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先进行域名过滤
|
||||
if !p.isDomainAllowed(host) {
|
||||
if p.DomainFilterMode {
|
||||
return fmt.Errorf("domain not in whitelist: %s", host)
|
||||
}
|
||||
return fmt.Errorf("domain in blacklist: %s", host)
|
||||
}
|
||||
|
||||
// 若未启用对域名应用IP过滤,则到此通过
|
||||
if !p.ApplyIPFilterForDomain {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析域名对应IP并检查
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if !p.IsIPAccessAllowed(ip) {
|
||||
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
|
||||
}
|
||||
if p.IpFilterMode {
|
||||
return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
|
||||
}
|
||||
return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
|
||||
func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
|
||||
// 如果SSRF防护被禁用,直接返回成功
|
||||
if !enableSSRFProtection {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析端口范围配置
|
||||
allowedPortInts, err := parsePortRanges(allowedPorts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request reject - invalid port configuration: %v", err)
|
||||
}
|
||||
|
||||
protection := &SSRFProtection{
|
||||
AllowPrivateIp: allowPrivateIp,
|
||||
DomainFilterMode: domainFilterMode,
|
||||
DomainList: domainList,
|
||||
IpFilterMode: ipFilterMode,
|
||||
IpList: ipList,
|
||||
AllowedPorts: allowedPortInts,
|
||||
ApplyIPFilterForDomain: applyIPFilterForDomain,
|
||||
}
|
||||
return protection.ValidateURL(urlStr)
|
||||
}
|
||||
@@ -11,8 +11,10 @@ const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
TaskActionFirstTailGenerate = "firstTailGenerate"
|
||||
TaskActionReferenceGenerate = "referenceGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
@@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
// VolcEngine 图像生成模型
|
||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||
requestPath = "/v1/images/generations"
|
||||
}
|
||||
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||||
@@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
}
|
||||
}
|
||||
|
||||
// 重新检查模型类型并更新请求路径
|
||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||
strings.HasPrefix(testModel, "m3e") ||
|
||||
strings.Contains(testModel, "bge-") ||
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == constant.ChannelTypeMokaAI {
|
||||
requestPath = "/v1/embeddings"
|
||||
c.Request.URL.Path = requestPath
|
||||
}
|
||||
|
||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||
requestPath = "/v1/images/generations"
|
||||
c.Request.URL.Path = requestPath
|
||||
}
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
@@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
if c.Request.URL.Path == "/v1/embeddings" {
|
||||
relayFormat = types.RelayFormatEmbedding
|
||||
}
|
||||
if c.Request.URL.Path == "/v1/images/generations" {
|
||||
relayFormat = types.RelayFormatOpenAIImage
|
||||
}
|
||||
|
||||
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||
|
||||
@@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
}
|
||||
// 调用专门用于 Embedding 的转换函数
|
||||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||
} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
// 创建一个 ImageRequest
|
||||
prompt := "cat"
|
||||
if request.Prompt != nil {
|
||||
if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
|
||||
prompt = promptStr
|
||||
}
|
||||
}
|
||||
imageRequest := dto.ImageRequest{
|
||||
Prompt: prompt,
|
||||
Model: request.Model,
|
||||
N: uint(request.N),
|
||||
Size: request.Size,
|
||||
}
|
||||
// 调用专门用于图像生成的转换函数
|
||||
convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
|
||||
} else {
|
||||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
@@ -235,7 +274,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -500,9 +501,10 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
}
|
||||
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
}
|
||||
|
||||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||
@@ -560,7 +562,7 @@ func AddChannel(c *gin.Context) {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -585,7 +587,7 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
@@ -615,6 +617,13 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
localChannel := addChannelRequest.Channel
|
||||
localChannel.Key = key
|
||||
if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
|
||||
keyPrefix := localChannel.Key
|
||||
if len(localChannel.Key) > 8 {
|
||||
keyPrefix = localChannel.Key[:8]
|
||||
}
|
||||
localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
|
||||
}
|
||||
channels = append(channels, *localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
@@ -840,7 +849,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,11 +58,7 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"server_address": system_setting.ServerAddress,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -75,15 +71,15 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
"usd_exchange_rate": operation_setting.USDExchangeRate,
|
||||
"price": operation_setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
@@ -253,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
|
||||
@@ -128,6 +128,33 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ImageRatio":
|
||||
err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "图片倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "AudioRatio":
|
||||
err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "音频倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "AudioCompletionRatio":
|
||||
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "音频补全倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||
if err != nil {
|
||||
|
||||
@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||
|
||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||
if newAPIError != nil && preConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||
if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
|
||||
gopool.Go(func() {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
gopool.Go(func() {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
|
||||
@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
|
||||
func PostSetup(c *gin.Context) {
|
||||
// Check if setup is already completed
|
||||
if constant.Setup {
|
||||
c.JSON(400, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "系统已经初始化完成",
|
||||
})
|
||||
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
|
||||
var req SetupRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数有误",
|
||||
})
|
||||
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
|
||||
if !rootExists {
|
||||
// Validate username length: max 12 characters to align with model.User validation
|
||||
if len(req.Username) > 12 {
|
||||
c.JSON(400, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名长度不能超过12个字符",
|
||||
})
|
||||
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(400, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "两次输入的密码不一致",
|
||||
})
|
||||
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
c.JSON(400, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "密码长度至少为8个字符",
|
||||
})
|
||||
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
|
||||
// Create root user
|
||||
hashedPassword, err := common.Password2Hash(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "系统错误: " + err.Error(),
|
||||
})
|
||||
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
err = model.DB.Create(&rootUser).Error
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "创建管理员账号失败: " + err.Error(),
|
||||
})
|
||||
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
|
||||
// Save operation modes to database for persistence
|
||||
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "保存自用模式设置失败: " + err.Error(),
|
||||
})
|
||||
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
|
||||
|
||||
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "保存演示站点模式设置失败: " + err.Error(),
|
||||
})
|
||||
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
err = model.DB.Create(&setup).Error
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "系统初始化失败: " + err.Error(),
|
||||
})
|
||||
@@ -178,4 +178,4 @@ func boolToString(b bool) string {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = responseBody
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -117,7 +117,9 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
@@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -19,6 +21,44 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func GetTopUpInfo(c *gin.Context) {
|
||||
// 获取支付方式
|
||||
payMethods := operation_setting.PayMethods
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
|
||||
// 检查是否已经包含 Stripe
|
||||
hasStripe := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == "stripe" {
|
||||
hasStripe = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStripe {
|
||||
stripeMethod := map[string]string{
|
||||
"name": "Stripe",
|
||||
"type": "stripe",
|
||||
"color": "rgba(var(--semi-purple-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.StripeMinTopUp),
|
||||
}
|
||||
payMethods = append(payMethods, stripeMethod)
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
type EpayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
@@ -31,13 +71,13 @@ type AmountRequest struct {
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
||||
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: setting.EpayId,
|
||||
Key: setting.EpayKey,
|
||||
}, setting.PayAddress)
|
||||
PartnerID: operation_setting.EpayId,
|
||||
Key: operation_setting.EpayKey,
|
||||
}, operation_setting.PayAddress)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -58,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
|
||||
}
|
||||
|
||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||
dPrice := decimal.NewFromFloat(setting.Price)
|
||||
dPrice := decimal.NewFromFloat(operation_setting.Price)
|
||||
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
dDiscount := decimal.NewFromFloat(discount)
|
||||
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
|
||||
|
||||
return payMoney.InexactFloat64()
|
||||
}
|
||||
|
||||
func getMinTopup() int64 {
|
||||
minTopup := setting.MinTopUp
|
||||
minTopup := operation_setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
@@ -99,13 +147,13 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -215,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||
SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
|
||||
CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(setting.StripePriceId),
|
||||
@@ -254,6 +256,7 @@ func GetChargedAmount(count float64, user model.User) float64 {
|
||||
}
|
||||
|
||||
func getStripePayMoney(amount float64, group string) float64 {
|
||||
originalAmount := amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
@@ -262,7 +265,14 @@ func getStripePayMoney(amount float64, group string) float64 {
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
|
||||
return payMoney
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,14 @@ type ChannelSettings struct {
|
||||
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||
}
|
||||
|
||||
type VertexKeyType string
|
||||
|
||||
const (
|
||||
VertexKeyTypeJSON VertexKeyType = "json"
|
||||
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
}
|
||||
|
||||
@@ -2,12 +2,11 @@ package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
@@ -269,15 +268,14 @@ type GeminiChatResponse struct {
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"`
|
||||
CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"`
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
}
|
||||
|
||||
type GeminiModalityTokenCount struct {
|
||||
type GeminiPromptTokensDetails struct {
|
||||
Modality string `json:"modality"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
@@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 序列化时需要重新把字段平铺
|
||||
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
||||
// 将已定义字段转为 map
|
||||
type Alias ImageRequest
|
||||
alias := Alias(r)
|
||||
base, err := common.Marshal(alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var baseMap map[string]json.RawMessage
|
||||
if err := common.Unmarshal(base, &baseMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 ExtraFields
|
||||
for k, v := range r.Extra {
|
||||
if _, exists := baseMap[k]; !exists {
|
||||
baseMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(baseMap)
|
||||
}
|
||||
|
||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||
fields := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
|
||||
@@ -772,11 +772,12 @@ type OpenAIResponsesRequest struct {
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
|
||||
@@ -6,6 +6,10 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
const (
|
||||
ResponsesOutputTypeImageGenerationCall = "image_generation_call"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error any `json:"error"`
|
||||
@@ -273,6 +277,42 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
|
||||
if len(o.Output) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, output := range o.Output {
|
||||
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *OpenAIResponsesResponse) GetQuality() string {
|
||||
if len(o.Output) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, output := range o.Output {
|
||||
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||
return output.Quality
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *OpenAIResponsesResponse) GetSize() string {
|
||||
if len(o.Output) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, output := range o.Output {
|
||||
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||
return output.Size
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
@@ -283,6 +323,8 @@ type ResponsesOutput struct {
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Content []ResponsesOutputContent `json:"content"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
|
||||
@@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode = relayconstant.RelayModeVideoSubmit
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
|
||||
@@ -42,7 +42,6 @@ type Channel struct {
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
@@ -51,6 +50,8 @@ type Channel struct {
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
|
||||
|
||||
// cache info
|
||||
Keys []string `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -66,16 +67,16 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
|
||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
|
||||
common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
|
||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
|
||||
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
|
||||
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||
@@ -85,7 +86,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -111,6 +112,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
|
||||
common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
|
||||
common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -271,7 +275,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
@@ -293,29 +297,29 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
setting.ServerAddress = value
|
||||
system_setting.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
setting.WorkerUrl = value
|
||||
system_setting.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
setting.WorkerValidKey = value
|
||||
system_setting.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
setting.PayAddress = value
|
||||
operation_setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
operation_setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
setting.EpayId = value
|
||||
operation_setting.EpayId = value
|
||||
case "EpayKey":
|
||||
setting.EpayKey = value
|
||||
operation_setting.EpayKey = value
|
||||
case "Price":
|
||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
operation_setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "USDExchangeRate":
|
||||
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
operation_setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "StripeApiSecret":
|
||||
setting.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
@@ -396,6 +400,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "ImageRatio":
|
||||
err = ratio_setting.UpdateImageRatioByJSONString(value)
|
||||
case "AudioRatio":
|
||||
err = ratio_setting.UpdateAudioRatioByJSONString(value)
|
||||
case "AudioCompletionRatio":
|
||||
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
@@ -413,7 +423,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
err = operation_setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
|
||||
@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 检查是否为Nova模型
|
||||
if isNovaModel(request.Model) {
|
||||
novaReq := convertToNovaRequest(request)
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", novaReq)
|
||||
c.Set("is_nova_model", true)
|
||||
return novaReq, nil
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
c.Set("is_nova_model", false)
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package aws
|
||||
|
||||
import "strings"
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
@@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
// Nova models - all support three major regions
|
||||
"amazon.nova-micro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-lite-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-pro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
}}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
@@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
// 判断是否为Nova模型
|
||||
func isNovaModel(modelId string) bool {
|
||||
return strings.HasPrefix(modelId, "nova-")
|
||||
}
|
||||
|
||||
@@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
Thinking: req.Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
// NovaMessage Nova模型使用messages-v1格式
|
||||
type NovaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []NovaContent `json:"content"`
|
||||
}
|
||||
|
||||
type NovaContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type NovaRequest struct {
|
||||
SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
|
||||
Messages []NovaMessage `json:"messages"` // 对话消息列表
|
||||
InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
|
||||
}
|
||||
|
||||
type NovaInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
|
||||
Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
|
||||
TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
|
||||
TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
|
||||
StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
|
||||
}
|
||||
|
||||
// 转换OpenAI请求为Nova格式
|
||||
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
|
||||
novaMessages := make([]NovaMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
novaMessages[i] = NovaMessage{
|
||||
Role: msg.Role,
|
||||
Content: []NovaContent{{Text: msg.StringContent()}},
|
||||
}
|
||||
}
|
||||
|
||||
novaReq := &NovaRequest{
|
||||
SchemaVersion: "messages-v1",
|
||||
Messages: novaMessages,
|
||||
}
|
||||
|
||||
// 设置推理配置
|
||||
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
|
||||
novaReq.InferenceConfig = &NovaInferenceConfig{}
|
||||
if req.MaxTokens != 0 {
|
||||
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
|
||||
}
|
||||
if req.Temperature != nil && *req.Temperature != 0 {
|
||||
novaReq.InferenceConfig.Temperature = *req.Temperature
|
||||
}
|
||||
if req.TopP != 0 {
|
||||
novaReq.InferenceConfig.TopP = req.TopP
|
||||
}
|
||||
if req.TopK != 0 {
|
||||
novaReq.InferenceConfig.TopK = req.TopK
|
||||
}
|
||||
if req.Stop != nil {
|
||||
if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
|
||||
novaReq.InferenceConfig.StopSequences = stopSequences
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return novaReq
|
||||
}
|
||||
|
||||
// parseStopSequences 解析停止序列,支持字符串或字符串数组
|
||||
func parseStopSequences(stop any) []string {
|
||||
if stop == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := stop.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
return []string{v}
|
||||
}
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var sequences []string
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok && str != "" {
|
||||
sequences = append(sequences, str)
|
||||
}
|
||||
}
|
||||
return sequences
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
@@ -209,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
// 解析Nova响应
|
||||
var novaResp struct {
|
||||
Output struct {
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"inputTokens"`
|
||||
OutputTokens int `json:"outputTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
|
||||
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
// 构造OpenAI格式响应
|
||||
response := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
Choices: []dto.OpenAITextResponseChoice{{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: novaResp.Output.Message.Content[0].Text,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: novaResp.Usage.InputTokens,
|
||||
CompletionTokens: novaResp.Usage.OutputTokens,
|
||||
TotalTokens: novaResp.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
return nil, &response.Usage
|
||||
}
|
||||
|
||||
@@ -3,17 +3,17 @@ package deepseek
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
@@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fimBaseUrl := info.ChannelBaseUrl
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
|
||||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
|
||||
if strings.Contains(info.RequestURLPath, ":embedContent") ||
|
||||
strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
|
||||
return NativeGeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
if info.IsStream {
|
||||
|
||||
@@ -46,32 +46,6 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputCounts := 0
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||
imageOutputCounts++
|
||||
}
|
||||
}
|
||||
}
|
||||
if imageOutputCounts != 0 {
|
||||
usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
|
||||
usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
|
||||
c.Set("gemini_image_tokens", imageOutputCounts*1290)
|
||||
}
|
||||
}
|
||||
|
||||
// if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
// for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
// if detail.Modality == "IMAGE" {
|
||||
// usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
// usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
// c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
@@ -162,16 +136,6 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
if detail.Modality == "IMAGE" {
|
||||
usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
|
||||
var geminiSupportedMimeTypes = map[string]bool{
|
||||
"application/pdf": true,
|
||||
"audio/mpeg": true,
|
||||
@@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/webp": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
|
||||
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
if responsesResponse.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", responsesResponse.GetQuality())
|
||||
c.Set("image_generation_call_size", responsesResponse.GetSize())
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
@@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -37,6 +36,7 @@ type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
Frames int `json:"frames,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -89,26 +89,14 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
@@ -116,7 +104,12 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
} else {
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
@@ -176,6 +169,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
||||
if isNewAPIRelay(key) {
|
||||
uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
|
||||
}
|
||||
payload := map[string]string{
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
@@ -193,17 +189,20 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
if isNewAPIRelay(key) {
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
} else {
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
}
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
@@ -327,18 +326,23 @@ func hmacSHA256(key []byte, data []byte) []byte {
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: "jimeng_vgfm_i2v_l20",
|
||||
Prompt: req.Prompt,
|
||||
AspectRatio: "16:9", // Default aspect ratio
|
||||
Seed: -1, // Default to random
|
||||
ReqKey: req.Model,
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
|
||||
switch req.Duration {
|
||||
case 10:
|
||||
r.Frames = 241 // 24*10+1 = 241
|
||||
default:
|
||||
r.Frames = 121 // 24*5+1 = 121
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
if req.Image != "" {
|
||||
if strings.HasPrefix(req.Image, "http") {
|
||||
r.ImageUrls = []string{req.Image}
|
||||
if req.HasImage() {
|
||||
if strings.HasPrefix(req.Images[0], "http") {
|
||||
r.ImageUrls = req.Images
|
||||
} else {
|
||||
r.BinaryDataBase64 = []string{req.Image}
|
||||
r.BinaryDataBase64 = req.Images
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
@@ -350,6 +354,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
// 即梦视频3.0 ReqKey转换
|
||||
// https://www.volcengine.com/docs/85621/1792707
|
||||
if strings.Contains(r.ReqKey, "jimeng_v30") {
|
||||
if len(r.ImageUrls) > 1 {
|
||||
// 多张图片:首尾帧生成
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
|
||||
} else if len(r.ImageUrls) == 1 {
|
||||
// 单张图片:图生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
|
||||
} else {
|
||||
// 无图片:文生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
|
||||
}
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
@@ -378,3 +398,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
taskResult.Url = resTask.Data.VideoUrl
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -28,16 +27,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TrajectoryPoint struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
@@ -121,28 +110,18 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
@@ -166,7 +145,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -225,6 +204,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
}
|
||||
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||
if isNewAPIRelay(key) {
|
||||
url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@@ -255,7 +237,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
@@ -330,8 +312,13 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
//}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
|
||||
if isNewAPIRelay(apiKey) {
|
||||
return apiKey, nil // new api relay
|
||||
}
|
||||
keyParts := strings.Split(apiKey, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return "", errors.New("invalid api_key, required format is accessKey|secretKey")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
if len(keyParts) == 1 {
|
||||
return accessKey, nil
|
||||
@@ -378,3 +365,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
355
relay/channel/task/vertex/adaptor.go
Normal file
355
relay/channel/task/vertex/adaptor.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
vertexcore "one-api/relay/channel/vertex"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type requestPayload struct {
|
||||
Instances []map[string]any `json:"instances"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type submitResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type operationVideo struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
type operationResponse struct {
|
||||
Name string `json:"name"`
|
||||
Done bool `json:"done"`
|
||||
Response struct {
|
||||
Type string `json:"@type"`
|
||||
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
|
||||
Videos []operationVideo `json:"videos"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
Video string `json:"video"`
|
||||
} `json:"response"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
|
||||
region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "global"
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body := requestPayload{
|
||||
Instances: []map[string]any{{"prompt": req.Prompt}},
|
||||
Parameters: map[string]any{},
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
if v, ok := req.Metadata["storageUri"]; ok {
|
||||
body.Parameters["storageUri"] = v
|
||||
}
|
||||
if v, ok := req.Metadata["sampleCount"]; ok {
|
||||
body.Parameters["sampleCount"] = v
|
||||
}
|
||||
}
|
||||
if _, ok := body.Parameters["sampleCount"]; !ok {
|
||||
body.Parameters["sampleCount"] = 1
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if project == "" || modelName == "" {
|
||||
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||
}
|
||||
var url string
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
if op.Error.Message != "" {
|
||||
ti.Status = model.TaskStatusFailure
|
||||
ti.Reason = op.Error.Message
|
||||
ti.Progress = "100%"
|
||||
return ti, nil
|
||||
}
|
||||
if !op.Done {
|
||||
ti.Status = model.TaskStatusInProgress
|
||||
ti.Progress = "50%"
|
||||
return ti, nil
|
||||
}
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
if len(op.Response.Videos) > 0 {
|
||||
v0 := op.Response.Videos[0]
|
||||
if v0.BytesBase64Encoded != "" {
|
||||
mime := strings.TrimSpace(v0.MimeType)
|
||||
if mime == "" {
|
||||
enc := strings.TrimSpace(v0.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
if strings.Contains(enc, "/") {
|
||||
mime = enc
|
||||
} else {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
}
|
||||
if op.Response.BytesBase64Encoded != "" {
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
if op.Response.Video != "" { // some variants use `video` as base64
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.Video
|
||||
return ti, nil
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
m := regionRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||
|
||||
func extractModelFromOperationName(name string) string {
|
||||
m := modelRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
idx := strings.Index(name, "models/")
|
||||
if idx >= 0 {
|
||||
s := name[idx+len("models/"):]
|
||||
if p := strings.Index(s, "/operations/"); p > 0 {
|
||||
return s[:p]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
|
||||
|
||||
func extractProjectFromOperationName(name string) string {
|
||||
m := projectRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -23,16 +23,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
@@ -90,23 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
@@ -114,7 +88,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -137,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
path = "/img2video"
|
||||
case constant.TaskActionFirstTailGenerate:
|
||||
path = "/start-end2video"
|
||||
case constant.TaskActionReferenceGenerate:
|
||||
path = "/reference2video"
|
||||
default:
|
||||
path = "/text2video"
|
||||
}
|
||||
@@ -211,15 +189,10 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Images: images,
|
||||
Images: req.Images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
@@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &Credentials{}
|
||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
|
||||
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
||||
a.AccountCredentials = *adc
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
adc := &Credentials{}
|
||||
if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
a.AccountCredentials = *adc
|
||||
|
||||
if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -111,24 +160,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
@@ -139,41 +171,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, model, suffix)
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
return a.getRequestUrl(info, "", "")
|
||||
}
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
if a.AccountCredentials.ProjectID != "" {
|
||||
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
|
||||
if m[localModelName] != nil {
|
||||
return m[localModelName].(string)
|
||||
} else {
|
||||
return m["default"].(string)
|
||||
if v, ok := m["default"]; ok {
|
||||
return v.(string)
|
||||
}
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
return other
|
||||
|
||||
@@ -6,14 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
@@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
|
||||
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
|
||||
signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
|
||||
resp, err := client.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken, ok := result["access_token"].(string); ok {
|
||||
return accessToken, nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
@@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return request, nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
|
||||
@@ -8,6 +8,7 @@ var ModelList = []string{
|
||||
"Doubao-lite-32k",
|
||||
"Doubao-lite-4k",
|
||||
"Doubao-embedding",
|
||||
"doubao-seedream-4-0-250828",
|
||||
}
|
||||
|
||||
var ChannelName = "volcengine"
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
if request.System == nil {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
if request.IsStringSystem() {
|
||||
existing := strings.TrimSpace(request.GetStringSystem())
|
||||
if existing == "" {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
|
||||
} else {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
|
||||
}
|
||||
} else {
|
||||
systemContents := request.ParseSystem()
|
||||
newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
|
||||
newSystem.SetText(info.ChannelSetting.SystemPrompt)
|
||||
if len(systemContents) == 0 {
|
||||
request.System = []dto.ClaudeMediaMessage{newSystem}
|
||||
} else {
|
||||
request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
@@ -111,7 +137,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -481,11 +481,20 @@ type TaskSubmitReq struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) GetPrompt() string {
|
||||
return t.Prompt
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) HasImage() bool {
|
||||
return len(t.Images) > 0
|
||||
}
|
||||
|
||||
type TaskInfo struct {
|
||||
Code int `json:"code"`
|
||||
TaskID string `json:"task_id"`
|
||||
|
||||
@@ -2,12 +2,23 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type HasPrompt interface {
|
||||
GetPrompt() string
|
||||
}
|
||||
|
||||
type HasImage interface {
|
||||
HasImage() bool
|
||||
}
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
@@ -30,3 +41,56 @@ func GetAPIVersion(c *gin.Context) string {
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
|
||||
return &dto.TaskError{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
StatusCode: statusCode,
|
||||
LocalError: localError,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
|
||||
info.Action = action
|
||||
c.Set("task_request", requestObj)
|
||||
}
|
||||
|
||||
func validatePrompt(prompt string) *dto.TaskError {
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
|
||||
// 兼容单图上传
|
||||
req.Images = []string{req.Image}
|
||||
}
|
||||
|
||||
if req.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
if info.ChannelType == constant.ChannelTypeVidu {
|
||||
// vidu 增加 首尾帧生视频和参考图生视频
|
||||
if len(req.Images) == 2 {
|
||||
action = constant.TaskActionFirstTailGenerate
|
||||
} else if len(req.Images) > 2 {
|
||||
action = constant.TaskActionReferenceGenerate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,41 +90,43 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
// 如果有系统提示,则将其添加到请求中
|
||||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
if ok {
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,7 +160,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newApiErr := service.RelayErrorHandler(httpResp, false)
|
||||
newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
@@ -195,6 +197,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
@@ -204,6 +208,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -211,12 +216,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
@@ -271,6 +278,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
var imageGenerationCallPrice float64
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
|
||||
@@ -284,6 +298,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
// 减去 image tokens
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
@@ -302,7 +321,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
Add(imageTokensWithRatio).
|
||||
Add(dCachedCreationTokensWithRatio)
|
||||
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
|
||||
@@ -314,22 +335,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
var dGeminiImageOutputQuota decimal.Decimal
|
||||
var imageOutputPrice float64
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName)
|
||||
if imageOutputPrice > 0 {
|
||||
dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens")))
|
||||
dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 Gemini image output 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota)
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -395,6 +407,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["image_ratio"] = imageRatio
|
||||
other["image_output"] = imageTokens
|
||||
}
|
||||
if cachedCreationTokens != 0 {
|
||||
other["cache_creation_tokens"] = cachedCreationTokens
|
||||
other["cache_creation_ratio"] = cachedCreationRatio
|
||||
}
|
||||
if !dWebSearchQuota.IsZero() {
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||
@@ -424,9 +440,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dGeminiImageOutputQuota.IsZero() {
|
||||
other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens")
|
||||
other["image_output_price"] = imageOutputPrice
|
||||
if !dImageGenerationCallQuota.IsZero() {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = imageGenerationCallPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
|
||||
@@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/gemini"
|
||||
@@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
if request.SystemInstructions == nil {
|
||||
request.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: info.ChannelSetting.SystemPrompt},
|
||||
},
|
||||
}
|
||||
} else if len(request.SystemInstructions.Parts) == 0 {
|
||||
request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
merged := false
|
||||
for i := range request.SystemInstructions.Parts {
|
||||
if request.SystemInstructions.Parts[i].Text == "" {
|
||||
continue
|
||||
}
|
||||
request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
|
||||
merged = true
|
||||
break
|
||||
}
|
||||
if !merged {
|
||||
request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up empty system instruction
|
||||
if request.SystemInstructions != nil {
|
||||
hasContent := false
|
||||
@@ -152,7 +179,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
@@ -249,7 +276,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
@@ -52,6 +52,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
var cacheRatio float64
|
||||
var imageRatio float64
|
||||
var cacheCreationRatio float64
|
||||
var audioRatio float64
|
||||
var audioCompletionRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
|
||||
if meta.MaxTokens != 0 {
|
||||
@@ -73,6 +75,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
|
||||
audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
|
||||
audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
@@ -90,6 +94,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
AudioRatio: audioRatio,
|
||||
AudioCompletionRatio: audioCompletionRatio,
|
||||
CacheCreationRatio: cacheCreationRatio,
|
||||
ShouldPreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
|
||||
@@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
request, err = GetAndValidateGeminiEmbeddingRequest(c)
|
||||
} else {
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
}
|
||||
case types.RelayFormatClaude:
|
||||
request, err = GetAndValidateClaudeRequest(c)
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
@@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
@@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
|
||||
request := &dto.GeminiEmbeddingRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
@@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
var logContent string
|
||||
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskvertex "one-api/relay/channel/task/vertex"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
@@ -37,6 +37,8 @@ import (
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
@@ -126,6 +128,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
return &taskvertex.TaskAdaptor{}
|
||||
case constant.ChannelTypeVidu:
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -33,6 +35,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
@@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
if taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
if len(respBody) == 0 {
|
||||
respBody = []byte("{\"code\":\"success\",\"data\":null}")
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
@@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
func() {
|
||||
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi {
|
||||
return
|
||||
}
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
}
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
})
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err2 := io.ReadAll(resp.Body)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
ti, err2 := adaptor.ParseTaskResult(body)
|
||||
if err2 == nil && ti != nil {
|
||||
if ti.Status != "" {
|
||||
originTask.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
originTask.Progress = ti.Progress
|
||||
}
|
||||
if ti.Url != "" {
|
||||
originTask.FailReason = ti.Url
|
||||
}
|
||||
_ = originTask.Update()
|
||||
var raw map[string]any
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
format := "mp4"
|
||||
if respObj, ok := raw["response"].(map[string]any); ok {
|
||||
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
||||
if v0, ok := vids[0].(map[string]any); ok {
|
||||
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
||||
if strings.Contains(mt, "mp4") {
|
||||
format = "mp4"
|
||||
} else {
|
||||
format = mt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status := "processing"
|
||||
switch originTask.Status {
|
||||
case model.TaskStatusSuccess:
|
||||
status = "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
status = "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
status = "queued"
|
||||
}
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": status,
|
||||
"task_id": originTask.TaskID,
|
||||
"url": originTask.FailReason,
|
||||
}
|
||||
respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if len(respBody) == 0 {
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
adaptor.Init(info)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||
@@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
httpResp = resp.(*http.Response)
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -60,6 +60,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.GET("/topup/info", controller.GetTopUpInfo)
|
||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
|
||||
selfRoute.POST("/amount", controller.RequestAmount)
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// WorkerRequest Worker请求的数据结构
|
||||
type WorkerRequest struct {
|
||||
URL string `json:"url"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// DoWorkerRequest 通过Worker发送请求
|
||||
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
if !setting.EnableWorker() {
|
||||
return nil, fmt.Errorf("worker not enabled")
|
||||
}
|
||||
if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
|
||||
workerUrl := setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
|
||||
// 序列化worker请求数据
|
||||
workerPayload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
|
||||
}
|
||||
|
||||
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||||
}
|
||||
|
||||
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||||
if setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
req := &WorkerRequest{
|
||||
URL: originUrl,
|
||||
Key: setting.WorkerValidKey,
|
||||
}
|
||||
return DoWorkerRequest(req)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
69
service/download.go
Normal file
69
service/download.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// WorkerRequest Worker请求的数据结构
|
||||
type WorkerRequest struct {
|
||||
URL string `json:"url"`
|
||||
Key string `json:"key"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// DoWorkerRequest 通过Worker发送请求
|
||||
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
if !system_setting.EnableWorker() {
|
||||
return nil, fmt.Errorf("worker not enabled")
|
||||
}
|
||||
if !system_setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
|
||||
// SSRF防护:验证请求URL
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return nil, fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
workerUrl := system_setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
|
||||
// 序列化worker请求数据
|
||||
workerPayload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
|
||||
}
|
||||
|
||||
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||||
}
|
||||
|
||||
func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
|
||||
if system_setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
req := &WorkerRequest{
|
||||
URL: originUrl,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
}
|
||||
return DoWorkerRequest(req)
|
||||
} else {
|
||||
// SSRF防护:验证请求URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return nil, fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if setting.CustomCallbackAddress == "" {
|
||||
return setting.ServerAddress
|
||||
if operation_setting.CustomCallbackAddress == "" {
|
||||
return system_setting.ServerAddress
|
||||
}
|
||||
return setting.CustomCallbackAddress
|
||||
return operation_setting.CustomCallbackAddress
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
||||
return claudeErr
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
@@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
} else {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
}
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -13,13 +13,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
if relayInfo.FinalPreConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||
err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
|
||||
|
||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||
// It returns the pre-consumed quota if successful, or an error if not.
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
|
||||
trustQuota := common.GetTrustQuota()
|
||||
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
||||
}
|
||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||
return preConsumedQuota, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -534,7 +534,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
}
|
||||
if quotaTooLow {
|
||||
prompt := "您的额度即将用尽"
|
||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||
topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
|
||||
|
||||
// 根据通知方式生成不同的内容格式
|
||||
var content string
|
||||
|
||||
@@ -336,7 +336,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
for i, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") {
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 256
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -91,11 +91,11 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
if setting.EnableWorker() {
|
||||
if system_setting.EnableWorker() {
|
||||
// 使用worker发送请求
|
||||
workerReq := &WorkerRequest{
|
||||
URL: finalURL,
|
||||
Key: setting.WorkerValidKey,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
Method: http.MethodGet,
|
||||
Headers: map[string]string{
|
||||
"User-Agent": "OneAPI-Bark-Notify/1.0",
|
||||
@@ -113,6 +113,12 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
|
||||
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
// SSRF防护:验证Bark URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
// 直接发送请求
|
||||
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -56,11 +57,11 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
||||
var req *http.Request
|
||||
var resp *http.Response
|
||||
|
||||
if setting.EnableWorker() {
|
||||
if system_setting.EnableWorker() {
|
||||
// 构建worker请求数据
|
||||
workerReq := &WorkerRequest{
|
||||
URL: webhookURL,
|
||||
Key: setting.WorkerValidKey,
|
||||
Key: system_setting.WorkerValidKey,
|
||||
Method: http.MethodPost,
|
||||
Headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
@@ -86,6 +87,12 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
||||
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
// SSRF防护:验证Webhook URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create webhook request: %v", err)
|
||||
|
||||
@@ -26,7 +26,6 @@ var defaultGeminiSettings = GeminiSettings{
|
||||
SupportedImagineModels: []string{
|
||||
"gemini-2.0-flash-exp-image-generation",
|
||||
"gemini-2.0-flash-exp",
|
||||
"gemini-2.5-flash-image-preview",
|
||||
},
|
||||
ThinkingAdapterEnabled: false,
|
||||
ThinkingAdapterBudgetTokensPercentage: 0.6,
|
||||
|
||||
23
setting/operation_setting/payment_setting.go
Normal file
23
setting/operation_setting/payment_setting.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package operation_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type PaymentSetting struct {
|
||||
AmountOptions []int `json:"amount_options"`
|
||||
AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var paymentSetting = PaymentSetting{
|
||||
AmountOptions: []int{10, 20, 50, 100, 200, 500},
|
||||
AmountDiscount: map[int]float64{},
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("payment_setting", &paymentSetting)
|
||||
}
|
||||
|
||||
func GetPaymentSetting() *PaymentSetting {
|
||||
return &paymentSetting
|
||||
}
|
||||
@@ -1,6 +1,13 @@
|
||||
package setting
|
||||
/**
|
||||
此文件为旧版支付设置文件,如需增加新的参数、变量等,请在 payment_setting.go 中添加
|
||||
This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go
|
||||
*/
|
||||
|
||||
import "encoding/json"
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
@@ -21,15 +28,21 @@ var PayMethods = []map[string]string{
|
||||
"color": "rgba(var(--semi-green-5), 1)",
|
||||
"type": "wxpay",
|
||||
},
|
||||
{
|
||||
"name": "自定义1",
|
||||
"color": "black",
|
||||
"type": "custom1",
|
||||
"min_topup": "50",
|
||||
},
|
||||
}
|
||||
|
||||
func UpdatePayMethodsByJsonString(jsonString string) error {
|
||||
PayMethods = make([]map[string]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
return common.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
}
|
||||
|
||||
func PayMethods2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(PayMethods)
|
||||
jsonBytes, err := common.Marshal(PayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
@@ -10,6 +10,18 @@ const (
|
||||
FileSearchPrice = 2.5
|
||||
)
|
||||
|
||||
const (
|
||||
GPTImage1Low1024x1024 = 0.011
|
||||
GPTImage1Low1024x1536 = 0.016
|
||||
GPTImage1Low1536x1024 = 0.016
|
||||
GPTImage1Medium1024x1024 = 0.042
|
||||
GPTImage1Medium1024x1536 = 0.063
|
||||
GPTImage1Medium1536x1024 = 0.063
|
||||
GPTImage1High1024x1024 = 0.167
|
||||
GPTImage1High1024x1536 = 0.25
|
||||
GPTImage1High1536x1024 = 0.25
|
||||
)
|
||||
|
||||
const (
|
||||
// Gemini Audio Input Price
|
||||
Gemini25FlashPreviewInputAudioPrice = 1.00
|
||||
@@ -24,10 +36,6 @@ const (
|
||||
ClaudeWebSearchPrice = 10.00
|
||||
)
|
||||
|
||||
const (
|
||||
Gemini25FlashImagePreviewImageOutputPrice = 30.00
|
||||
)
|
||||
|
||||
func GetClaudeWebSearchPricePerThousand() float64 {
|
||||
return ClaudeWebSearchPrice
|
||||
}
|
||||
@@ -70,9 +78,30 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 {
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
|
||||
return Gemini25FlashImagePreviewImageOutputPrice
|
||||
func GetGPTImage1PriceOnceCall(quality string, size string) float64 {
|
||||
prices := map[string]map[string]float64{
|
||||
"low": {
|
||||
"1024x1024": GPTImage1Low1024x1024,
|
||||
"1024x1536": GPTImage1Low1024x1536,
|
||||
"1536x1024": GPTImage1Low1536x1024,
|
||||
},
|
||||
"medium": {
|
||||
"1024x1024": GPTImage1Medium1024x1024,
|
||||
"1024x1536": GPTImage1Medium1024x1536,
|
||||
"1536x1024": GPTImage1Medium1536x1024,
|
||||
},
|
||||
"high": {
|
||||
"1024x1024": GPTImage1High1024x1024,
|
||||
"1024x1536": GPTImage1High1024x1536,
|
||||
"1536x1024": GPTImage1High1536x1024,
|
||||
},
|
||||
}
|
||||
return 0
|
||||
|
||||
if qualityMap, exists := prices[quality]; exists {
|
||||
if price, exists := qualityMap[size]; exists {
|
||||
return price
|
||||
}
|
||||
}
|
||||
|
||||
return GPTImage1High1024x1024
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-2.5-flash-lite-preview-thinking-*": 0.05,
|
||||
"gemini-2.5-flash-lite-preview-06-17": 0.05,
|
||||
"gemini-2.5-flash": 0.15,
|
||||
"gemini-2.5-flash-image-preview": 0.15, // $0.30(text/image) / 1M tokens
|
||||
"gemini-embedding-001": 0.075,
|
||||
"text-embedding-004": 0.001,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
@@ -279,6 +279,18 @@ var defaultModelPrice = map[string]float64{
|
||||
"mj_upload": 0.05,
|
||||
}
|
||||
|
||||
var defaultAudioRatio = map[string]float64{
|
||||
"gpt-4o-audio-preview": 16,
|
||||
"gpt-4o-mini-audio-preview": 66.67,
|
||||
"gpt-4o-realtime-preview": 8,
|
||||
"gpt-4o-mini-realtime-preview": 16.67,
|
||||
}
|
||||
|
||||
var defaultAudioCompletionRatio = map[string]float64{
|
||||
"gpt-4o-realtime": 2,
|
||||
"gpt-4o-mini-realtime": 2,
|
||||
}
|
||||
|
||||
var (
|
||||
modelPriceMap map[string]float64 = nil
|
||||
modelPriceMapMutex = sync.RWMutex{}
|
||||
@@ -294,11 +306,10 @@ var (
|
||||
)
|
||||
|
||||
var defaultCompletionRatio = map[string]float64{
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"gpt-4o-gizmo-*": 3,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-image-1": 8,
|
||||
"gemini-2.5-flash-image-preview": 8.3333333333,
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"gpt-4o-gizmo-*": 3,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-image-1": 8,
|
||||
}
|
||||
|
||||
// InitRatioSettings initializes all model related settings maps
|
||||
@@ -328,6 +339,15 @@ func InitRatioSettings() {
|
||||
imageRatioMap = defaultImageRatio
|
||||
imageRatioMapMutex.Unlock()
|
||||
|
||||
// initialize audioRatioMap
|
||||
audioRatioMapMutex.Lock()
|
||||
audioRatioMap = defaultAudioRatio
|
||||
audioRatioMapMutex.Unlock()
|
||||
|
||||
// initialize audioCompletionRatioMap
|
||||
audioCompletionRatioMapMutex.Lock()
|
||||
audioCompletionRatioMap = defaultAudioCompletionRatio
|
||||
audioCompletionRatioMapMutex.Unlock()
|
||||
}
|
||||
|
||||
func GetModelPriceMap() map[string]float64 {
|
||||
@@ -419,6 +439,18 @@ func GetDefaultModelRatioMap() map[string]float64 {
|
||||
return defaultModelRatio
|
||||
}
|
||||
|
||||
func GetDefaultImageRatioMap() map[string]float64 {
|
||||
return defaultImageRatio
|
||||
}
|
||||
|
||||
func GetDefaultAudioRatioMap() map[string]float64 {
|
||||
return defaultAudioRatio
|
||||
}
|
||||
|
||||
func GetDefaultAudioCompletionRatioMap() map[string]float64 {
|
||||
return defaultAudioCompletionRatio
|
||||
}
|
||||
|
||||
func GetCompletionRatioMap() map[string]float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
@@ -586,32 +618,22 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
}
|
||||
|
||||
func GetAudioRatio(name string) float64 {
|
||||
if strings.Contains(name, "-realtime") {
|
||||
if strings.HasSuffix(name, "gpt-4o-realtime-preview") {
|
||||
return 8
|
||||
} else if strings.Contains(name, "gpt-4o-mini-realtime-preview") {
|
||||
return 10 / 0.6
|
||||
} else {
|
||||
return 20
|
||||
}
|
||||
}
|
||||
if strings.Contains(name, "-audio") {
|
||||
if strings.HasPrefix(name, "gpt-4o-audio-preview") {
|
||||
return 40 / 2.5
|
||||
} else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") {
|
||||
return 10 / 0.15
|
||||
} else {
|
||||
return 40
|
||||
}
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
if ratio, ok := audioRatioMap[name]; ok {
|
||||
return ratio
|
||||
}
|
||||
return 20
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
return 2
|
||||
} else if strings.HasPrefix(name, "gpt-4o-mini-realtime") {
|
||||
return 2
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
if ratio, ok := audioCompletionRatioMap[name]; ok {
|
||||
|
||||
return ratio
|
||||
}
|
||||
return 2
|
||||
}
|
||||
@@ -632,6 +654,14 @@ var defaultImageRatio = map[string]float64{
|
||||
}
|
||||
var imageRatioMap map[string]float64
|
||||
var imageRatioMapMutex sync.RWMutex
|
||||
var (
|
||||
audioRatioMap map[string]float64 = nil
|
||||
audioRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
var (
|
||||
audioCompletionRatioMap map[string]float64 = nil
|
||||
audioCompletionRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func ImageRatio2JSONString() string {
|
||||
imageRatioMapMutex.RLock()
|
||||
@@ -660,6 +690,71 @@ func GetImageRatio(name string) (float64, bool) {
|
||||
return ratio, true
|
||||
}
|
||||
|
||||
func AudioRatio2JSONString() string {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(audioRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling audio ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateAudioRatioByJSONString(jsonStr string) error {
|
||||
|
||||
tmp := make(map[string]float64)
|
||||
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
audioRatioMapMutex.Lock()
|
||||
audioRatioMap = tmp
|
||||
audioRatioMapMutex.Unlock()
|
||||
InvalidateExposedDataCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAudioRatioCopy() map[string]float64 {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(audioRatioMap))
|
||||
for k, v := range audioRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func AudioCompletionRatio2JSONString() string {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(audioCompletionRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling audio completion ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
|
||||
tmp := make(map[string]float64)
|
||||
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
audioCompletionRatioMapMutex.Lock()
|
||||
audioCompletionRatioMap = tmp
|
||||
audioCompletionRatioMapMutex.Unlock()
|
||||
InvalidateExposedDataCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatioCopy() map[string]float64 {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(audioCompletionRatioMap))
|
||||
for k, v := range audioCompletionRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func GetModelRatioCopy() map[string]float64 {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
|
||||
34
setting/system_setting/fetch_setting.go
Normal file
34
setting/system_setting/fetch_setting.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package system_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type FetchSetting struct {
|
||||
EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护
|
||||
AllowPrivateIp bool `json:"allow_private_ip"`
|
||||
DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式
|
||||
IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式
|
||||
DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com
|
||||
IpList []string `json:"ip_list"` // CIDR format
|
||||
AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000
|
||||
ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性)
|
||||
}
|
||||
|
||||
var defaultFetchSetting = FetchSetting{
|
||||
EnableSSRFProtection: true, // 默认开启SSRF防护
|
||||
AllowPrivateIp: false,
|
||||
DomainFilterMode: false,
|
||||
IpFilterMode: false,
|
||||
DomainList: []string{},
|
||||
IpList: []string{},
|
||||
AllowedPorts: []string{"80", "443", "8080", "8443"},
|
||||
ApplyIPFilterForDomain: false,
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting)
|
||||
}
|
||||
|
||||
func GetFetchSetting() *FetchSetting {
|
||||
return &defaultFetchSetting
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package setting
|
||||
package system_setting
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
@@ -122,6 +122,9 @@ func (e *NewAPIError) MaskSensitiveError() string {
|
||||
return string(e.errorCode)
|
||||
}
|
||||
errStr := e.Err.Error()
|
||||
if e.errorCode == ErrorCodeCountTokenFailed {
|
||||
return errStr
|
||||
}
|
||||
return common.MaskSensitiveInfo(errStr)
|
||||
}
|
||||
|
||||
@@ -153,8 +156,9 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
||||
Code: e.errorCode,
|
||||
}
|
||||
}
|
||||
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
if e.errorCode != ErrorCodeCountTokenFailed {
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -178,13 +182,23 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
|
||||
Type: string(e.errorType),
|
||||
}
|
||||
}
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
if e.errorCode != ErrorCodeCountTokenFailed {
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type NewAPIErrorOptions func(*NewAPIError)
|
||||
|
||||
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
var newErr *NewAPIError
|
||||
// 保留深层传递的 new err
|
||||
if errors.As(err, &newErr) {
|
||||
for _, op := range ops {
|
||||
op(newErr)
|
||||
}
|
||||
return newErr
|
||||
}
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: nil,
|
||||
@@ -199,8 +213,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
|
||||
}
|
||||
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
if errorCode == ErrorCodeDoRequestFailed {
|
||||
err = errors.New("upstream error: do request failed")
|
||||
var newErr *NewAPIError
|
||||
// 保留深层传递的 new err
|
||||
if errors.As(err, &newErr) {
|
||||
if newErr.RelayError == nil {
|
||||
openaiError := OpenAIError{
|
||||
Message: newErr.Error(),
|
||||
Type: string(errorCode),
|
||||
Code: errorCode,
|
||||
}
|
||||
newErr.RelayError = openaiError
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(newErr)
|
||||
}
|
||||
return newErr
|
||||
}
|
||||
openaiError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
@@ -305,6 +332,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
if common.DebugEnabled {
|
||||
fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
|
||||
}
|
||||
e.Err = errors.New(replaceStr)
|
||||
}
|
||||
}
|
||||
|
||||
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||
if e == nil {
|
||||
return false
|
||||
|
||||
@@ -15,6 +15,8 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
AudioRatio float64
|
||||
AudioCompletionRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
@@ -27,5 +29,5 @@ type PerCallPriceData struct {
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
||||
}
|
||||
|
||||
1
web/.env.example
Normal file
1
web/.env.example
Normal file
@@ -0,0 +1 @@
|
||||
VITE_CLERK_PUBLISHABLE_KEY=
|
||||
128
web/.github/CODE_OF_CONDUCT.md
vendored
Normal file
128
web/.github/CODE_OF_CONDUCT.md
vendored
Normal file
@@ -0,0 +1,128 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.0, available at
|
||||
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
||||
enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
https://www.contributor-covenant.org/faq. Translations are available at
|
||||
https://www.contributor-covenant.org/translations.
|
||||
101
web/.github/CONTRIBUTING.md
vendored
Normal file
101
web/.github/CONTRIBUTING.md
vendored
Normal file
@@ -0,0 +1,101 @@
|
||||
# Contributing to Shadcn-Admin
|
||||
|
||||
Thank you for considering contributing to **shadcn-admin**! Every contribution is valuable, whether it's reporting bugs, suggesting improvements, adding features, or refining README.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Getting Started](#getting-started)
|
||||
2. [How to Contribute](#how-to-contribute)
|
||||
3. [Code Standards](#code-standards)
|
||||
4. [Pull Request Guidelines](#pull-request-guidelines)
|
||||
5. [Reporting Issues](#reporting-issues)
|
||||
6. [Community Guidelines](#community-guidelines)
|
||||
|
||||
---
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. **Fork** the repository.
|
||||
2. **Clone** your fork:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-username/shadcn-admin.git
|
||||
```
|
||||
|
||||
3. **Install dependencies:**
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
```
|
||||
|
||||
4. **Run the project locally:**
|
||||
|
||||
```bash
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
5. Create a new branch for your contribution:
|
||||
|
||||
```bash
|
||||
git checkout -b feature/your-feature
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## How to Contribute
|
||||
|
||||
- **Feature Requests:** Open an issue or start a discussion to discuss the feature before implementation.
|
||||
- **Bug Fixes:** Provide clear reproduction steps in your issue.
|
||||
- **Documentation:** Improvements to the documentation (README) are always appreciated.
|
||||
|
||||
> **Note:** Pull Requests adding new features without a prior issue or discussion will **not be accepted**.
|
||||
|
||||
---
|
||||
|
||||
## Code Standards
|
||||
|
||||
- Follow the existing **ESLint** and **Prettier** configurations.
|
||||
- Ensure your code is **type-safe** with **TypeScript**.
|
||||
- Maintain consistency with the existing code structure.
|
||||
|
||||
> **Tips!** Before submitting your changes, run the following commands:
|
||||
|
||||
```bash
|
||||
pnpm lint && pnpm format && pnpm knip && pnpm build
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pull Request Guidelines
|
||||
|
||||
- **Follow the [PR Template](./PULL_REQUEST_TEMPLATE.md):**
|
||||
- Description
|
||||
- Types of changes
|
||||
- Checklist
|
||||
- Further comments
|
||||
- Related Issue
|
||||
- Ensure your changes pass **CI checks**.
|
||||
- Keep PRs **focused** and **concise**.
|
||||
- Reference related issues in your PR description.
|
||||
|
||||
---
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
- Clearly describe the issue.
|
||||
- Provide reproduction steps if applicable.
|
||||
- Include screenshots or code examples if relevant.
|
||||
|
||||
---
|
||||
|
||||
## Community Guidelines
|
||||
|
||||
- Be respectful and constructive.
|
||||
- Follow the [Code of Conduct](./CODE_OF_CONDUCT.md).
|
||||
- Stay on topic in discussions.
|
||||
|
||||
---
|
||||
|
||||
Thank you for helping make **shadcn-admin** better! 🚀
|
||||
|
||||
If you have any questions, feel free to reach out via [Discussions](https://github.com/satnaing/shadcn-admin/discussions).
|
||||
14
web/.github/FUNDING.yml
vendored
Normal file
14
web/.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
github: [satnaing]
|
||||
buy_me_a_coffee: satnaing
|
||||
|
||||
# patreon: # Replace with a single Patreon username
|
||||
# open_collective: # Replace with a single Open Collective username
|
||||
# ko_fi: # Replace with a single Ko-fi username
|
||||
# tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
# community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
# liberapay: # Replace with a single Liberapay username
|
||||
# issuehunt: # Replace with a single IssueHunt username
|
||||
# lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
# polar: # Replace with a single Polar username
|
||||
# thanks_dev: # Replace with a single thanks.dev username
|
||||
# custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
5
web/.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
web/.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Shadcn-Admin Discussions
|
||||
url: https://github.com/satnaing/shadcn-admin/discussions
|
||||
about: Please ask and answer questions here.
|
||||
19
web/.github/ISSUE_TEMPLATE/✨-feature-request.md
vendored
Normal file
19
web/.github/ISSUE_TEMPLATE/✨-feature-request.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
---
|
||||
name: "✨ Feature Request"
|
||||
about: Suggest an idea for improving Shadcn-Admin
|
||||
title: "[Feature Request]: "
|
||||
labels: enhancement
|
||||
assignees: ""
|
||||
---
|
||||
|
||||
**Is your feature request related to a problem? Please describe.**
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
**Describe the solution you'd like**
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
**Describe alternatives you've considered**
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
**Additional context**
|
||||
Add any other context or screenshots about the feature request here.
|
||||
27
web/.github/ISSUE_TEMPLATE/🐞-bug-report.md
vendored
Normal file
27
web/.github/ISSUE_TEMPLATE/🐞-bug-report.md
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
---
|
||||
name: "\U0001F41E Bug report"
|
||||
about: Report a bug or unexpected behavior in Shadcn-Admin
|
||||
title: "[BUG]: "
|
||||
labels: bug
|
||||
assignees: ""
|
||||
---
|
||||
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1. Go to '...'
|
||||
2. Click on '....'
|
||||
3. Scroll down to '....'
|
||||
4. See error
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Additional context**
|
||||
Add any other context about the problem here.
|
||||
27
web/.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
27
web/.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
## Description
|
||||
|
||||
<!-- A clear and concise description of what the pull request does. Include any relevant motivation and background. -->
|
||||
|
||||
## Types of changes
|
||||
|
||||
<!-- What types of changes does your code introduce to AstroPaper? Put an `x` in the boxes that apply -->
|
||||
|
||||
- [ ] Bug Fix (non-breaking change which fixes an issue)
|
||||
- [ ] New Feature (non-breaking change which adds functionality)
|
||||
- [ ] Others (any other types not listed above)
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Please follow this checklist and put an x in each of the boxes, like this: [x]. You can also fill these out after creating the PR. This is simply a reminder of what we are going to look for before merging your code. -->
|
||||
|
||||
- [ ] I have read the [Contributing Guide](https://github.com/satnaing/shadcn-admin/blob/main/.github/CONTRIBUTING.md)
|
||||
|
||||
## Further comments
|
||||
|
||||
<!-- If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc... -->
|
||||
|
||||
## Related Issue
|
||||
|
||||
<!-- If this PR is related to an existing issue, link to it here. -->
|
||||
|
||||
Closes: #<!-- Issue number, if applicable -->
|
||||
41
web/.github/workflows/ci.yml
vendored
Normal file
41
web/.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: Continuous Integration
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
install-lint-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
|
||||
- name: Install pnpm
|
||||
run: npm install -g pnpm
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Lint the code
|
||||
run: pnpm lint
|
||||
|
||||
# - name: Analyze unused files and dependencies
|
||||
# run: pnpm knip
|
||||
|
||||
- name: Run Prettier check
|
||||
run: pnpm format:check
|
||||
|
||||
- name: Build the project
|
||||
run: pnpm build
|
||||
29
web/.github/workflows/stale.yml
vendored
Normal file
29
web/.github/workflows/stale.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Close inactive issues/PR
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '38 18 * * *'
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
days-before-issue-stale: 120
|
||||
days-before-issue-close: 120
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 120 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 120 days since being marked as stale."
|
||||
days-before-pr-stale: 120
|
||||
days-before-pr-close: 120
|
||||
stale-pr-label: "stale"
|
||||
stale-pr-message: "This PR is stale because it has been open for 120 days with no activity."
|
||||
close-pr-message: "This PR was closed because it has been inactive for 120 days since being marked as stale."
|
||||
operations-per-run: 0
|
||||
44
web/.gitignore
vendored
44
web/.gitignore
vendored
@@ -1,26 +1,26 @@
|
||||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||
|
||||
# dependencies
|
||||
/node_modules
|
||||
/.pnp
|
||||
.pnp.js
|
||||
|
||||
# testing
|
||||
/coverage
|
||||
|
||||
# production
|
||||
/build
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
.env
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
|
||||
18
web/.prettierignore
Normal file
18
web/.prettierignore
Normal file
@@ -0,0 +1,18 @@
|
||||
# Ignore everything
|
||||
/*
|
||||
|
||||
# Except these files & folders
|
||||
!/src
|
||||
!index.html
|
||||
!package.json
|
||||
!tailwind.config.js
|
||||
!tsconfig.json
|
||||
!tsconfig.node.json
|
||||
!vite.config.ts
|
||||
!.prettierrc
|
||||
!README.md
|
||||
!eslint.config.js
|
||||
!postcss.config.js
|
||||
|
||||
# Ignore auto generated routeTree.gen.ts
|
||||
/src/routeTree.gen.ts
|
||||
49
web/.prettierrc
Normal file
49
web/.prettierrc
Normal file
@@ -0,0 +1,49 @@
|
||||
{
|
||||
"arrowParens": "always",
|
||||
"semi": false,
|
||||
"tabWidth": 2,
|
||||
"printWidth": 80,
|
||||
"singleQuote": true,
|
||||
"jsxSingleQuote": true,
|
||||
"trailingComma": "es5",
|
||||
"bracketSpacing": true,
|
||||
"endOfLine": "lf",
|
||||
"plugins": [
|
||||
"@trivago/prettier-plugin-sort-imports",
|
||||
"prettier-plugin-tailwindcss"
|
||||
],
|
||||
"importOrder": [
|
||||
"^path$",
|
||||
"^vite$",
|
||||
"^@vitejs/(.*)$",
|
||||
"^react$",
|
||||
"^react-dom/client$",
|
||||
"^react/(.*)$",
|
||||
"^globals$",
|
||||
"^zod$",
|
||||
"^axios$",
|
||||
"^date-fns$",
|
||||
"^react-hook-form$",
|
||||
"^use-intl$",
|
||||
"^@radix-ui/(.*)$",
|
||||
"^@hookform/resolvers/zod$",
|
||||
"^@tanstack/react-query$",
|
||||
"^@tanstack/react-router$",
|
||||
"^@tanstack/react-table$",
|
||||
"<THIRD_PARTY_MODULES>",
|
||||
"^@/assets/(.*)",
|
||||
"^@/api/(.*)$",
|
||||
"^@/stores/(.*)$",
|
||||
"^@/lib/(.*)$",
|
||||
"^@/utils/(.*)$",
|
||||
"^@/constants/(.*)$",
|
||||
"^@/context/(.*)$",
|
||||
"^@/hooks/(.*)$",
|
||||
"^@/components/layouts/(.*)$",
|
||||
"^@/components/ui/(.*)$",
|
||||
"^@/components/errors/(.*)$",
|
||||
"^@/components/(.*)$",
|
||||
"^@/features/(.*)$",
|
||||
"^[./]"
|
||||
]
|
||||
}
|
||||
297
web/CHANGELOG.md
Normal file
297
web/CHANGELOG.md
Normal file
@@ -0,0 +1,297 @@
|
||||
## v2.1.0 (2025-08-23)
|
||||
|
||||
### Feat
|
||||
|
||||
- enhance data table pagination with page numbers (#207)
|
||||
- enhance auth flow with sign-out dialogs and redirect functionality (#206)
|
||||
|
||||
### Refactor
|
||||
|
||||
- reorganize utility files into `lib/` folder (#209)
|
||||
- extract data-table components and reorganize structure (#208)
|
||||
|
||||
## v2.0.0 (2025-08-16)
|
||||
|
||||
### BREAKING CHANGE
|
||||
|
||||
- CSS file structure has been reorganized
|
||||
|
||||
### Feat
|
||||
|
||||
- add search param sync in apps route (#200)
|
||||
- improve tables and sync table states with search param (#199)
|
||||
- add data table bulk action toolbar (#196)
|
||||
- add config drawer and update overall layout (#186)
|
||||
- RTL support (#179)
|
||||
|
||||
### Fix
|
||||
|
||||
- adjust layout styles in search and top nav in dashboard page
|
||||
- update spacing and layout styles
|
||||
- update faceted icon color
|
||||
- improve user table hover & selected styles (#195)
|
||||
- add max-width for large screens to improve responsiveness (#194)
|
||||
- adjust chat border radius for better responsiveness (#193)
|
||||
- update hard-coded or inconsistent colors (#191)
|
||||
- use variable for inset layout height calculation
|
||||
- faded-bottom overflow issue in inset layout
|
||||
- hide unnecessary configs on mobile (#189)
|
||||
- adjust file input text vertical alignment (#188)
|
||||
|
||||
### Refactor
|
||||
|
||||
- enforce consistency and code quality (#198)
|
||||
- improve code quality and consistency (#197)
|
||||
- update error routes (#192)
|
||||
- remove DirSwitch component and its usage in Tasks (#190)
|
||||
- standardize using cookie as persist state (#187)
|
||||
- separate CSS into modular theme and base styles (#185)
|
||||
- replace tabler icons with lucide icons (#183)
|
||||
|
||||
## v1.4.2 (2025-07-23)
|
||||
|
||||
### Fix
|
||||
|
||||
- remove unnecessary transitions in table (#176)
|
||||
- overflow background in tables (#175)
|
||||
|
||||
## v1.4.1 (2025-06-25)
|
||||
|
||||
### Fix
|
||||
|
||||
- user list overflow in chat (#160)
|
||||
- prevent showing collapsed menu on mobile (#155)
|
||||
- white background select dropdown in dark mode (#149)
|
||||
|
||||
### Refactor
|
||||
|
||||
- update font config guide in fonts.ts (#164)
|
||||
|
||||
## v1.4.0 (2025-05-25)
|
||||
|
||||
### Feat
|
||||
|
||||
- **clerk**: add Clerk for auth and protected route (#146)
|
||||
|
||||
### Fix
|
||||
|
||||
- add an indicator for nested pages in search (#147)
|
||||
- update faded-bottom color with css variable (#139)
|
||||
|
||||
## v1.3.0 (2025-04-16)
|
||||
|
||||
### Fix
|
||||
|
||||
- replace custom otp with input-otp component (#131)
|
||||
- disable layout animation on mobile (#130)
|
||||
- upgrade react-day-picker and update calendar component (#129)
|
||||
|
||||
### Others
|
||||
|
||||
- upgrade Tailwind CSS to v4 (#125)
|
||||
- upgrade dependencies (#128)
|
||||
- configure automatic code-splitting (#127)
|
||||
|
||||
## v1.2.0 (2025-04-12)
|
||||
|
||||
### Feat
|
||||
|
||||
- add loading indicator during page transitions (#119)
|
||||
- add light favicons and theme-based switching (#112)
|
||||
- add new chat dialog in chats page (#90)
|
||||
|
||||
### Fix
|
||||
|
||||
- add fallback font for fontFamily (#110)
|
||||
- broken focus behavior in add user dialog (#113)
|
||||
|
||||
## v1.1.0 (2025-01-30)
|
||||
|
||||
### Feat
|
||||
|
||||
- allow changing font family in setting
|
||||
|
||||
### Fix
|
||||
|
||||
- update sidebar color in dark mode for consistent look (#87)
|
||||
- use overflow-clip in table paginations (#86)
|
||||
- **style**: update global scrollbar style (#82)
|
||||
- toolbar filter placeholder typo in user table (#76)
|
||||
|
||||
## v1.0.3 (2024-12-28)
|
||||
|
||||
### Fix
|
||||
|
||||
- add gap between buttons in import task dialog (#70)
|
||||
- hide button sort if column cannot be hidden & update filterFn (#69)
|
||||
- nav links added in profile dropdown (#68)
|
||||
|
||||
### Refactor
|
||||
|
||||
- optimize states in users/tasks context (#71)
|
||||
|
||||
## v1.0.2 (2024-12-25)
|
||||
|
||||
### Fix
|
||||
|
||||
- update overall layout due to scroll-lock bug (#66)
|
||||
|
||||
### Refactor
|
||||
|
||||
- analyze and remove unused files/exports with knip (#67)
|
||||
|
||||
## v1.0.1 (2024-12-14)
|
||||
|
||||
### Fix
|
||||
|
||||
- merge two button components into one (#60)
|
||||
- loading all tabler-icon chunks in dev mode (#59)
|
||||
- display menu dropdown when sidebar collapsed (#58)
|
||||
- update spacing & alignment in dialogs/drawers
|
||||
- update border & transition of sticky columns in user table
|
||||
- update heading alignment to left in user dialogs
|
||||
- add height and scroll area in user mutation dialogs
|
||||
- update `/dashboard` route to just `/`
|
||||
- **build**: replace require with import in tailwind.config.js
|
||||
|
||||
### Refactor
|
||||
|
||||
- remove unnecessary layout-backup file
|
||||
|
||||
## v1.0.0 (2024-12-09)
|
||||
|
||||
### BREAKING CHANGE
|
||||
|
||||
- Restructured the entire folder
|
||||
hierarchy to adopt a feature-based structure. This
|
||||
change improves code modularity and maintainability
|
||||
but introduces breaking changes.
|
||||
|
||||
### Feat
|
||||
|
||||
- implement task dialogs
|
||||
- implement user invite dialog
|
||||
- implement users CRUD
|
||||
- implement global command/search
|
||||
- implement custom sidebar trigger
|
||||
- implement coming-soon page
|
||||
|
||||
### Fix
|
||||
|
||||
- uncontrolled issue in account setting
|
||||
- card layout issue in app integrations page
|
||||
- remove form reset logic from useEffect in task import
|
||||
- update JSX types due to react 19
|
||||
- prevent card stretch in filtered app layout
|
||||
- layout wrap issue in tasks page on mobile
|
||||
- update user column hover and selected colors
|
||||
- add setTimeout in user dialog closing
|
||||
- layout shift issue in dropdown modal
|
||||
- z-axis overflow issue in header
|
||||
- stretch search bar only in mobile
|
||||
- language dropdown issue in account setting
|
||||
- update overflow contents with scroll area
|
||||
|
||||
### Refactor
|
||||
|
||||
- update layouts and extract common layout
|
||||
- reorganize project to feature-based structure
|
||||
|
||||
## v1.0.0-beta.5 (2024-11-11)
|
||||
|
||||
### Feat
|
||||
|
||||
- add multiple language support (#37)
|
||||
|
||||
### Fix
|
||||
|
||||
- ensure site syncs with system theme changes (#49)
|
||||
- recent sales responsive on ipad view (#40)
|
||||
|
||||
## v1.0.0-beta.4 (2024-09-22)
|
||||
|
||||
### Feat
|
||||
|
||||
- upgrade theme button to theme dropdown (#33)
|
||||
- **a11y**: add "Skip to Main" button to improve keyboard navigation (#27)
|
||||
|
||||
### Fix
|
||||
|
||||
- optimize onComplete/onIncomplete invocation (#32)
|
||||
- solve asChild attribute issue in custom button (#31)
|
||||
- improve custom Button component (#28)
|
||||
|
||||
## v1.0.0-beta.3 (2024-08-25)
|
||||
|
||||
### Feat
|
||||
|
||||
- implement chat page (#21)
|
||||
- add 401 error page (#12)
|
||||
- implement apps page
|
||||
- add otp page
|
||||
|
||||
### Fix
|
||||
|
||||
- prevent focus zoom on mobile devices (#20)
|
||||
- resolve eslint script issue (#18)
|
||||
- **a11y**: update default aria-label of each pin-input
|
||||
- resolve OTP paste issue in multi-digit pin-input
|
||||
- update layouts and solve overflow issues (#11)
|
||||
- sync pin inputs programmatically
|
||||
|
||||
## v1.0.0-beta.2 (2024-03-18)
|
||||
|
||||
### Feat
|
||||
|
||||
- implement custom pin-input component (#2)
|
||||
|
||||
## v1.0.0-beta.1 (2024-02-08)
|
||||
|
||||
### Feat
|
||||
|
||||
- update theme-color meta tag when theme is updated
|
||||
- add coming soon page in broken pages
|
||||
- implement tasks table and page
|
||||
- add remaining settings pages
|
||||
- add example error page for settings
|
||||
- update general error page to be more flexible
|
||||
- implement settings layout and settings profile page
|
||||
- add error pages
|
||||
- add password-input custom component
|
||||
- add sign-up page
|
||||
- add forgot-password page
|
||||
- add box sign in page
|
||||
- add email + password sign in page
|
||||
- make sidebar responsive and accessible
|
||||
- add tailwind prettier plugin
|
||||
- make sidebar collapsed state in local storage
|
||||
- add check current active nav hook
|
||||
- add loader component ui
|
||||
- update dropdown nav by default if child is active
|
||||
- add main-panel in dashboard
|
||||
- **ui**: add dark mode
|
||||
- **ui**: implement side nav ui
|
||||
|
||||
### Fix
|
||||
|
||||
- update incorrect overflow side nav height
|
||||
- exclude shadcn components from linting and remove unused props
|
||||
- solve text overflow issue when nav text is long
|
||||
- replace nav with dropdown in mobile topnav
|
||||
- make sidebar scrollable when overflow
|
||||
- update nav link keys
|
||||
- **ui**: update label style
|
||||
|
||||
### Refactor
|
||||
|
||||
- move password-input component into custom component dir
|
||||
- add custom button component
|
||||
- extract redundant codes into layout component
|
||||
- update react-router to use new api for routing
|
||||
- update main panel layout
|
||||
- update major layouts and styling
|
||||
- update main panel to be responsive
|
||||
- update sidebar collapsed state to false in mobile
|
||||
- update sidebar logo and title
|
||||
- **ui**: remove unnecessary spacing
|
||||
- remove unused files
|
||||
21
web/LICENSE
Normal file
21
web/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Sat Naing
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
117
web/README.md
Normal file
117
web/README.md
Normal file
@@ -0,0 +1,117 @@
|
||||
# Shadcn Admin Dashboard
|
||||
|
||||
Admin Dashboard UI crafted with Shadcn and Vite. Built with responsiveness and accessibility in mind.
|
||||
|
||||

|
||||
|
||||
I've been creating dashboard UIs at work and for my personal projects. I always wanted to make a reusable collection of dashboard UI for future projects; and here it is now. While I've created a few custom components, some of the code is directly adapted from ShadcnUI examples.
|
||||
|
||||
> This is not a starter project (template) though. I'll probably make one in the future.
|
||||
|
||||
## Features
|
||||
|
||||
- Light/dark mode
|
||||
- Responsive
|
||||
- Accessible
|
||||
- With built-in Sidebar component
|
||||
- Global search command
|
||||
- 10+ pages
|
||||
- Extra custom components
|
||||
- RTL support
|
||||
|
||||
<details>
|
||||
<summary>Customized Components (click to expand)</summary>
|
||||
|
||||
This project uses Shadcn UI components, but some have been slightly modified for better RTL (Right-to-Left) support and other improvements. These customized components differ from the original Shadcn UI versions.
|
||||
|
||||
If you want to update components using the Shadcn CLI (e.g., `npx shadcn@latest add <component>`), it's generally safe for non-customized components. For the listed customized ones, you may need to manually merge changes to preserve the project's modifications and avoid overwriting RTL support or other updates.
|
||||
|
||||
> If you don't require RTL support, you can safely update the 'RTL Updated Components' via the Shadcn CLI, as these changes are primarily for RTL compatibility. The 'Modified Components' may have other customizations to consider.
|
||||
|
||||
### Modified Components
|
||||
|
||||
- scroll-area
|
||||
- sonner
|
||||
- separator
|
||||
|
||||
### RTL Updated Components
|
||||
|
||||
- alert-dialog
|
||||
- calendar
|
||||
- command
|
||||
- dialog
|
||||
- dropdown-menu
|
||||
- select
|
||||
- table
|
||||
- sheet
|
||||
- sidebar
|
||||
- switch
|
||||
|
||||
**Notes:**
|
||||
|
||||
- **Modified Components**: These have general updates, potentially including RTL adjustments.
|
||||
- **RTL Updated Components**: These have specific changes for RTL language support (e.g., layout, positioning).
|
||||
- For implementation details, check the source files in `src/components/ui/`.
|
||||
- All other Shadcn UI components in the project are standard and can be safely updated via the CLI.
|
||||
|
||||
</details>
|
||||
|
||||
## Tech Stack
|
||||
|
||||
**UI:** [ShadcnUI](https://ui.shadcn.com) (TailwindCSS + RadixUI)
|
||||
|
||||
**Build Tool:** [Vite](https://vitejs.dev/)
|
||||
|
||||
**Routing:** [TanStack Router](https://tanstack.com/router/latest)
|
||||
|
||||
**Type Checking:** [TypeScript](https://www.typescriptlang.org/)
|
||||
|
||||
**Linting/Formatting:** [Eslint](https://eslint.org/) & [Prettier](https://prettier.io/)
|
||||
|
||||
**Icons:** [Lucide Icons](https://lucide.dev/icons/), [Tabler Icons](https://tabler.io/icons) (Brand icons only)
|
||||
|
||||
**Auth (partial):** [Clerk](https://go.clerk.com/GttUAaK)
|
||||
|
||||
## Run Locally
|
||||
|
||||
Clone the project
|
||||
|
||||
```bash
|
||||
git clone https://github.com/satnaing/shadcn-admin.git
|
||||
```
|
||||
|
||||
Go to the project directory
|
||||
|
||||
```bash
|
||||
cd shadcn-admin
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
```
|
||||
|
||||
Start the server
|
||||
|
||||
```bash
|
||||
pnpm run dev
|
||||
```
|
||||
|
||||
## Sponsoring this project ❤️
|
||||
|
||||
If you find this project helpful or use this in your own work, consider [sponsoring me](https://github.com/sponsors/satnaing) to support development and maintenance. You can [buy me a coffee](https://buymeacoffee.com/satnaing) as well. Don’t worry, every penny helps. Thank you! 🙏
|
||||
|
||||
For questions or sponsorship inquiries, feel free to reach out at [contact@satnaing.dev](mailto:contact@satnaing.dev).
|
||||
|
||||
### Current Sponsor
|
||||
|
||||
- [Clerk](https://go.clerk.com/GttUAaK) - for backing the implementation of Clerk in this project
|
||||
|
||||
## Author
|
||||
|
||||
Crafted with 🤍 by [@satnaing](https://github.com/satnaing)
|
||||
|
||||
## License
|
||||
|
||||
Licensed under the [MIT License](https://choosealicense.com/licenses/mit/)
|
||||
21
web/components.json
Normal file
21
web/components.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "new-york",
|
||||
"rsc": false,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "",
|
||||
"css": "src/styles/index.css",
|
||||
"baseColor": "slate",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils",
|
||||
"ui": "@/components/ui",
|
||||
"lib": "@/lib",
|
||||
"hooks": "@/hooks"
|
||||
},
|
||||
"iconLibrary": "lucide"
|
||||
}
|
||||
7
web/cz.yaml
Normal file
7
web/cz.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
commitizen:
|
||||
name: cz_conventional_commits
|
||||
tag_format: v$version
|
||||
update_changelog_on_bump: true
|
||||
version_provider: npm
|
||||
version_scheme: semver
|
||||
58
web/eslint.config.js
Normal file
58
web/eslint.config.js
Normal file
@@ -0,0 +1,58 @@
|
||||
import globals from 'globals'
|
||||
import js from '@eslint/js'
|
||||
import pluginQuery from '@tanstack/eslint-plugin-query'
|
||||
import reactHooks from 'eslint-plugin-react-hooks'
|
||||
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||
import tseslint from 'typescript-eslint'
|
||||
|
||||
export default tseslint.config(
|
||||
{ ignores: ['dist', 'src/components/ui'] },
|
||||
{
|
||||
extends: [
|
||||
js.configs.recommended,
|
||||
...tseslint.configs.recommended,
|
||||
...pluginQuery.configs['flat/recommended'],
|
||||
],
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
},
|
||||
plugins: {
|
||||
'react-hooks': reactHooks,
|
||||
'react-refresh': reactRefresh,
|
||||
},
|
||||
rules: {
|
||||
...reactHooks.configs.recommended.rules,
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
'no-console': 'error',
|
||||
'no-unused-vars': 'off',
|
||||
'@typescript-eslint/no-unused-vars': [
|
||||
'error',
|
||||
{
|
||||
args: 'all',
|
||||
argsIgnorePattern: '^_',
|
||||
caughtErrors: 'all',
|
||||
caughtErrorsIgnorePattern: '^_',
|
||||
destructuredArrayIgnorePattern: '^_',
|
||||
varsIgnorePattern: '^_',
|
||||
ignoreRestSiblings: true,
|
||||
},
|
||||
],
|
||||
// Enforce type-only imports for TypeScript types
|
||||
'@typescript-eslint/consistent-type-imports': [
|
||||
'error',
|
||||
{
|
||||
prefer: 'type-imports',
|
||||
fixStyle: 'inline-type-imports',
|
||||
disallowTypeAnnotations: false,
|
||||
},
|
||||
],
|
||||
// Prevent duplicate imports from the same module
|
||||
'no-duplicate-imports': 'error',
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1,20 +1,80 @@
|
||||
<!doctype html>
|
||||
<html lang="zh">
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="/logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<meta name="theme-color" content="#ffffff" />
|
||||
<meta charset="UTF-8" />
|
||||
<link
|
||||
rel="icon"
|
||||
type="image/svg+xml"
|
||||
href="/images/favicon.svg"
|
||||
media="(prefers-color-scheme: light)"
|
||||
/>
|
||||
<link
|
||||
rel="icon"
|
||||
type="image/svg+xml"
|
||||
href="/images/favicon_light.svg"
|
||||
media="(prefers-color-scheme: dark)"
|
||||
/>
|
||||
<link
|
||||
rel="icon"
|
||||
type="image/png"
|
||||
href="/images/favicon.png"
|
||||
media="(prefers-color-scheme: light)"
|
||||
/>
|
||||
<link
|
||||
rel="icon"
|
||||
type="image/png"
|
||||
href="/images/favicon_light.png"
|
||||
media="(prefers-color-scheme: dark)"
|
||||
/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
|
||||
<!-- Primary Meta Tags -->
|
||||
<title>Shadcn Admin</title>
|
||||
<meta name="title" content="Shadcn Admin" />
|
||||
<meta
|
||||
name="description"
|
||||
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
|
||||
content="Admin Dashboard UI built with Shadcn and Vite."
|
||||
/>
|
||||
<title>New API</title>
|
||||
|
||||
<!-- Open Graph / Facebook -->
|
||||
<meta property="og:type" content="website" />
|
||||
<meta property="og:url" content="https://shadcn-admin.netlify.app" />
|
||||
<meta property="og:title" content="Shadcn Admin" />
|
||||
<meta
|
||||
property="og:description"
|
||||
content="Admin Dashboard UI built with Shadcn and Vite."
|
||||
/>
|
||||
<meta
|
||||
property="og:image"
|
||||
content="https://shadcn-admin.netlify.app/images/shadcn-admin.png"
|
||||
/>
|
||||
|
||||
<!-- Twitter -->
|
||||
<meta property="twitter:card" content="summary_large_image" />
|
||||
<meta property="twitter:url" content="https://shadcn-admin.netlify.app" />
|
||||
<meta property="twitter:title" content="Shadcn Admin" />
|
||||
<meta
|
||||
property="twitter:description"
|
||||
content="Admin Dashboard UI built with Shadcn and Vite."
|
||||
/>
|
||||
<meta
|
||||
property="twitter:image"
|
||||
content="https://shadcn-admin.netlify.app/images/shadcn-admin.png"
|
||||
/>
|
||||
|
||||
<!-- font family -->
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:ital,opsz,wght@0,14..32,100..900;1,14..32,100..900&family=Manrope:wght@200..800&display=swap"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
|
||||
<meta name="theme-color" content="#fff" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<noscript>You need to enable JavaScript to run this app.</noscript>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/index.jsx"></script>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user