diff --git a/common/constants.go b/common/constants.go index e6d59d101..2ef2b7df2 100644 --- a/common/constants.go +++ b/common/constants.go @@ -19,6 +19,7 @@ var TopUpLink = "" // var ChatLink = "" // var ChatLink2 = "" var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens +// 保留旧变量以兼容历史逻辑,实际展示由 general_setting.quota_display_type 控制 var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true var DrawingEnabled = true diff --git a/common/database.go b/common/database.go index 38a54d5e6..71dbd94d5 100644 --- a/common/database.go +++ b/common/database.go @@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries var UsingMySQL = false var UsingClickHouse = false -var SQLitePath = "one-api.db?_busy_timeout=30000" \ No newline at end of file +var SQLitePath = "one-api.db?_busy_timeout=30000" diff --git a/constant/api_type.go b/constant/api_type.go index 0ea5048f2..130ae9455 100644 --- a/constant/api_type.go +++ b/constant/api_type.go @@ -31,7 +31,7 @@ const ( APITypeXai APITypeCoze APITypeJimeng - APITypeMoonshot - APITypeSubmodel - APITypeDummy // this one is only for count, do not add any channel after this + APITypeMoonshot + APITypeSubmodel + APITypeDummy // this one is only for count, do not add any channel after this ) diff --git a/controller/billing.go b/controller/billing.go index 1fb83633e..db3b62b1e 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -5,6 +5,7 @@ import ( "one-api/common" "one-api/dto" "one-api/model" + "one-api/setting/operation_setting" ) func GetSubscription(c *gin.Context) { @@ -39,8 +40,18 @@ func GetSubscription(c *gin.Context) { } quota := remainQuota + usedQuota amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + // OpenAI 兼容接口中的 *_USD 字段含义保持“额度单位”对应值: + // 我们将其解释为以“站点展示类型”为准: + // - USD: 直接除以 QuotaPerUnit + // - CNY: 先转 USD 再乘汇率 + // - TOKENS: 直接使用 tokens 数量 + switch operation_setting.GetQuotaDisplayType() { + case operation_setting.QuotaDisplayTypeCNY: + amount = amount / common.QuotaPerUnit * operation_setting.USDExchangeRate + case operation_setting.QuotaDisplayTypeTokens: + // amount 保持 tokens 数值 + default: + amount = amount / common.QuotaPerUnit } if token != nil && token.UnlimitedQuota { amount = 100000000 @@ -80,8 +91,13 @@ func GetUsage(c *gin.Context) { return } amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + switch operation_setting.GetQuotaDisplayType() { + case operation_setting.QuotaDisplayTypeCNY: + amount = amount / common.QuotaPerUnit * operation_setting.USDExchangeRate + case operation_setting.QuotaDisplayTypeTokens: + // tokens 保持原值 + default: + amount = amount / common.QuotaPerUnit } usage := OpenAIUsageResponse{ Object: "list", diff --git a/controller/misc.go b/controller/misc.go index 07f7d3f05..a3e017f87 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -66,18 +66,22 @@ func GetStatus(c *gin.Context) { "top_up_link": common.TopUpLink, "docs_link": operation_setting.GetGeneralSetting().DocsLink, "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, - "enable_batch_update": common.BatchUpdateEnabled, - "enable_drawing": common.DrawingEnabled, - "enable_task": common.TaskEnabled, - "enable_data_export": common.DataExportEnabled, - "data_export_default_time": common.DataExportDefaultTime, - "default_collapse_sidebar": common.DefaultCollapseSidebar, - "mj_notify_enabled": setting.MjNotifyEnabled, - "chats": setting.Chats, - "demo_site_enabled": operation_setting.DemoSiteEnabled, - "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, - "default_use_auto_group": setting.DefaultUseAutoGroup, + // 兼容旧前端:保留 display_in_currency,同时提供新的 quota_display_type + "display_in_currency": operation_setting.IsCurrencyDisplay(), + "quota_display_type": operation_setting.GetQuotaDisplayType(), + "custom_currency_symbol": operation_setting.GetGeneralSetting().CustomCurrencySymbol, + "custom_currency_exchange_rate": operation_setting.GetGeneralSetting().CustomCurrencyExchangeRate, + "enable_batch_update": common.BatchUpdateEnabled, + "enable_drawing": common.DrawingEnabled, + "enable_task": common.TaskEnabled, + "enable_data_export": common.DataExportEnabled, + "data_export_default_time": common.DataExportDefaultTime, + "default_collapse_sidebar": common.DefaultCollapseSidebar, + "mj_notify_enabled": setting.MjNotifyEnabled, + "chats": setting.Chats, + "demo_site_enabled": operation_setting.DemoSiteEnabled, + "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, + "default_use_auto_group": setting.DefaultUseAutoGroup, "usd_exchange_rate": operation_setting.USDExchangeRate, "price": operation_setting.Price, diff --git a/controller/setup.go b/controller/setup.go index 3ae255e94..8b7fa3e7a 100644 --- a/controller/setup.go +++ b/controller/setup.go @@ -178,4 +178,4 @@ func boolToString(b bool) string { return "true" } return "false" -} \ No newline at end of file +} diff --git a/controller/topup.go b/controller/topup.go index 243e67940..7e2cadf1d 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -86,8 +86,9 @@ func GetEpayClient() *epay.Client { func getPayMoney(amount int64, group string) float64 { dAmount := decimal.NewFromInt(amount) - - if !common.DisplayInCurrencyEnabled { + // 充值金额以“展示类型”为准: + // - USD/CNY: 前端传 amount 为金额单位;TOKENS: 前端传 tokens,需要换成 USD 金额 + if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens { dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) dAmount = dAmount.Div(dQuotaPerUnit) } @@ -115,7 +116,7 @@ func getPayMoney(amount int64, group string) float64 { func getMinTopup() int64 { minTopup := operation_setting.MinTopUp - if !common.DisplayInCurrencyEnabled { + if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens { dMinTopup := decimal.NewFromInt(int64(minTopup)) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart()) @@ -176,7 +177,7 @@ func RequestEpay(c *gin.Context) { return } amount := req.Amount - if !common.DisplayInCurrencyEnabled { + if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens { dAmount := decimal.NewFromInt(int64(amount)) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) amount = dAmount.Div(dQuotaPerUnit).IntPart() diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index 9a568d857..628a3fea5 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -258,7 +258,7 @@ func GetChargedAmount(count float64, user model.User) float64 { func getStripePayMoney(amount float64, group string) float64 { originalAmount := amount - if !common.DisplayInCurrencyEnabled { + if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens { amount = amount / common.QuotaPerUnit } // Using float64 for monetary calculations is acceptable here due to the small amounts involved @@ -279,7 +279,7 @@ func getStripePayMoney(amount float64, group string) float64 { func getStripeMinTopup() int64 { minTopup := setting.StripeMinTopUp - if !common.DisplayInCurrencyEnabled { + if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens { minTopup = minTopup * int(common.QuotaPerUnit) } return int64(minTopup) diff --git a/logger/logger.go b/logger/logger.go index d59e51cb8..68f564370 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -7,6 +7,7 @@ import ( "io" "log" "one-api/common" + "one-api/setting/operation_setting" "os" "path/filepath" "sync" @@ -92,18 +93,55 @@ func logHelper(ctx context.Context, level string, msg string) { } func LogQuota(quota int) string { - if common.DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) - } else { + // 新逻辑:根据额度展示类型输出 + q := float64(quota) + switch operation_setting.GetQuotaDisplayType() { + case operation_setting.QuotaDisplayTypeCNY: + usd := q / common.QuotaPerUnit + cny := usd * operation_setting.USDExchangeRate + return fmt.Sprintf("¥%.6f 额度", cny) + case operation_setting.QuotaDisplayTypeCustom: + usd := q / common.QuotaPerUnit + rate := operation_setting.GetGeneralSetting().CustomCurrencyExchangeRate + symbol := operation_setting.GetGeneralSetting().CustomCurrencySymbol + if symbol == "" { + symbol = "¤" + } + if rate <= 0 { + rate = 1 + } + v := usd * rate + return fmt.Sprintf("%s%.6f 额度", symbol, v) + case operation_setting.QuotaDisplayTypeTokens: return fmt.Sprintf("%d 点额度", quota) + default: // USD + return fmt.Sprintf("$%.6f 额度", q/common.QuotaPerUnit) } } func FormatQuota(quota int) string { - if common.DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit) - } else { + q := float64(quota) + switch operation_setting.GetQuotaDisplayType() { + case operation_setting.QuotaDisplayTypeCNY: + usd := q / common.QuotaPerUnit + cny := usd * operation_setting.USDExchangeRate + return fmt.Sprintf("¥%.6f", cny) + case operation_setting.QuotaDisplayTypeCustom: + usd := q / common.QuotaPerUnit + rate := operation_setting.GetGeneralSetting().CustomCurrencyExchangeRate + symbol := operation_setting.GetGeneralSetting().CustomCurrencySymbol + if symbol == "" { + symbol = "¤" + } + if rate <= 0 { + rate = 1 + } + v := usd * rate + return fmt.Sprintf("%s%.6f", symbol, v) + case operation_setting.QuotaDisplayTypeTokens: return fmt.Sprintf("%d", quota) + default: + return fmt.Sprintf("$%.6f", q/common.QuotaPerUnit) } } diff --git a/model/option.go b/model/option.go index 9ace8fece..77525ea25 100644 --- a/model/option.go +++ b/model/option.go @@ -240,7 +240,15 @@ func updateOptionMap(key string, value string) (err error) { case "LogConsumeEnabled": common.LogConsumeEnabled = boolValue case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue + // 兼容旧字段:同步到新配置 general_setting.quota_display_type(运行时生效) + // true -> USD, false -> TOKENS + newVal := "USD" + if !boolValue { + newVal = "TOKENS" + } + if cfg := config.GlobalConfig.Get("general_setting"); cfg != nil { + _ = config.UpdateConfigFromMap(cfg, map[string]string{"quota_display_type": newVal}) + } case "DisplayTokenStatEnabled": common.DisplayTokenStatEnabled = boolValue case "DrawingEnabled": diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index bafe73b92..d88e04d6f 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -18,7 +18,9 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { + return nil, errors.New("not implemented") +} func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { openaiAdaptor := openai.Adaptor{} @@ -33,17 +35,25 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest)) } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, errors.New("not implemented") +} -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil } - if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil } - return info.ChannelBaseUrl + "/api/chat", nil + if info.RelayMode == relayconstant.RelayModeEmbeddings { + return info.ChannelBaseUrl + "/api/embed", nil + } + if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { + return info.ChannelBaseUrl + "/api/generate", nil + } + return info.ChannelBaseUrl + "/api/chat", nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -53,7 +63,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel } func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { - if request == nil { return nil, errors.New("request is nil") } + if request == nil { + return nil, errors.New("request is nil") + } // decide generate or chat if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return openAIToGenerate(c, request) @@ -69,7 +81,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return requestOpenAI2Embeddings(request), nil } -func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") } +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, errors.New("not implemented") +} func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index 45e49ab43..2434a4cbc 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -5,12 +5,12 @@ import ( ) type OllamaChatMessage struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Images []string `json:"images,omitempty"` - ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"` - ToolName string `json:"tool_name,omitempty"` - Thinking json.RawMessage `json:"thinking,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` + ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"` + ToolName string `json:"tool_name,omitempty"` + Thinking json.RawMessage `json:"thinking,omitempty"` } type OllamaToolFunction struct { @@ -20,7 +20,7 @@ type OllamaToolFunction struct { } type OllamaTool struct { - Type string `json:"type"` + Type string `json:"type"` Function OllamaToolFunction `json:"function"` } @@ -43,28 +43,27 @@ type OllamaChatRequest struct { } type OllamaGenerateRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - Images []string `json:"images,omitempty"` - Format interface{} `json:"format,omitempty"` - Stream bool `json:"stream,omitempty"` - Options map[string]any `json:"options,omitempty"` - KeepAlive interface{} `json:"keep_alive,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + Suffix string `json:"suffix,omitempty"` + Images []string `json:"images,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` Think json.RawMessage `json:"think,omitempty"` } type OllamaEmbeddingRequest struct { - Model string `json:"model"` - Input interface{} `json:"input"` - Options map[string]any `json:"options,omitempty"` + Model string `json:"model"` + Input interface{} `json:"input"` + Options map[string]any `json:"options,omitempty"` Dimensions int `json:"dimensions,omitempty"` } type OllamaEmbeddingResponse struct { - Error string `json:"error,omitempty"` - Model string `json:"model"` - Embeddings [][]float64 `json:"embeddings"` - PromptEvalCount int `json:"prompt_eval_count,omitempty"` + Error string `json:"error,omitempty"` + Model string `json:"model"` + Embeddings [][]float64 `json:"embeddings"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` } - diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 3b67f9525..f94a654c7 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -35,13 +35,27 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam } // options mapping - if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature } - if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP } - if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK } - if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty } - if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty } - if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) } - if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) } + if r.Temperature != nil { + chatReq.Options["temperature"] = r.Temperature + } + if r.TopP != 0 { + chatReq.Options["top_p"] = r.TopP + } + if r.TopK != 0 { + chatReq.Options["top_k"] = r.TopK + } + if r.FrequencyPenalty != 0 { + chatReq.Options["frequency_penalty"] = r.FrequencyPenalty + } + if r.PresencePenalty != 0 { + chatReq.Options["presence_penalty"] = r.PresencePenalty + } + if r.Seed != 0 { + chatReq.Options["seed"] = int(r.Seed) + } + if mt := r.GetMaxTokens(); mt != 0 { + chatReq.Options["num_predict"] = int(mt) + } if r.Stop != nil { switch v := r.Stop.(type) { @@ -50,21 +64,27 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam case []string: chatReq.Options["stop"] = v case []any: - arr := make([]string,0,len(v)) - for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } } - if len(arr)>0 { chatReq.Options["stop"] = arr } + arr := make([]string, 0, len(v)) + for _, i := range v { + if s, ok := i.(string); ok { + arr = append(arr, s) + } + } + if len(arr) > 0 { + chatReq.Options["stop"] = arr + } } } if len(r.Tools) > 0 { - tools := make([]OllamaTool,0,len(r.Tools)) + tools := make([]OllamaTool, 0, len(r.Tools)) for _, t := range r.Tools { tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}}) } chatReq.Tools = tools } - chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages)) + chatReq.Messages = make([]OllamaChatMessage, 0, len(r.Messages)) for _, m := range r.Messages { var textBuilder strings.Builder var images []string @@ -79,14 +99,20 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam var base64Data string if strings.HasPrefix(img.Url, "http") { fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat") - if err != nil { return nil, err } + if err != nil { + return nil, err + } base64Data = fileData.Base64Data } else if strings.HasPrefix(img.Url, "data:") { - if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] } + if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { + base64Data = img.Url[idx+1:] + } } else { base64Data = img.Url } - if base64Data != "" { images = append(images, base64Data) } + if base64Data != "" { + images = append(images, base64Data) + } } } else if part.Type == dto.ContentTypeText { textBuilder.WriteString(part.Text) @@ -94,16 +120,24 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam } } cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()} - if len(images)>0 { cm.Images = images } - if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name } + if len(images) > 0 { + cm.Images = images + } + if m.Role == "tool" && m.Name != nil { + cm.ToolName = *m.Name + } if m.ToolCalls != nil && len(m.ToolCalls) > 0 { parsed := m.ParseToolCalls() if len(parsed) > 0 { - calls := make([]OllamaToolCall,0,len(parsed)) + calls := make([]OllamaToolCall, 0, len(parsed)) for _, tc := range parsed { var args interface{} - if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) } - if args==nil { args = map[string]any{} } + if tc.Function.Arguments != "" { + _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) + } + if args == nil { + args = map[string]any{} + } oc := OllamaToolCall{} oc.Function.Name = tc.Function.Name oc.Function.Arguments = args @@ -132,28 +166,67 @@ func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGener gen.Prompt = v case []any: var sb strings.Builder - for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } } + for _, it := range v { + if s, ok := it.(string); ok { + sb.WriteString(s) + } + } gen.Prompt = sb.String() default: gen.Prompt = fmt.Sprintf("%v", r.Prompt) } } - if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } } - if r.ResponseFormat != nil { - if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema } + if r.Suffix != nil { + if s, ok := r.Suffix.(string); ok { + gen.Suffix = s + } + } + if r.ResponseFormat != nil { + if r.ResponseFormat.Type == "json" { + gen.Format = "json" + } else if r.ResponseFormat.Type == "json_schema" { + var schema any + _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema) + gen.Format = schema + } + } + if r.Temperature != nil { + gen.Options["temperature"] = r.Temperature + } + if r.TopP != 0 { + gen.Options["top_p"] = r.TopP + } + if r.TopK != 0 { + gen.Options["top_k"] = r.TopK + } + if r.FrequencyPenalty != 0 { + gen.Options["frequency_penalty"] = r.FrequencyPenalty + } + if r.PresencePenalty != 0 { + gen.Options["presence_penalty"] = r.PresencePenalty + } + if r.Seed != 0 { + gen.Options["seed"] = int(r.Seed) + } + if mt := r.GetMaxTokens(); mt != 0 { + gen.Options["num_predict"] = int(mt) } - if r.Temperature != nil { gen.Options["temperature"] = r.Temperature } - if r.TopP != 0 { gen.Options["top_p"] = r.TopP } - if r.TopK != 0 { gen.Options["top_k"] = r.TopK } - if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty } - if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty } - if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) } - if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) } if r.Stop != nil { switch v := r.Stop.(type) { - case string: gen.Options["stop"] = []string{v} - case []string: gen.Options["stop"] = v - case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr } + case string: + gen.Options["stop"] = []string{v} + case []string: + gen.Options["stop"] = v + case []any: + arr := make([]string, 0, len(v)) + for _, i := range v { + if s, ok := i.(string); ok { + arr = append(arr, s) + } + } + if len(arr) > 0 { + gen.Options["stop"] = arr + } } } return gen, nil @@ -161,30 +234,51 @@ func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGener func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest { opts := map[string]any{} - if r.Temperature != nil { opts["temperature"] = r.Temperature } - if r.TopP != 0 { opts["top_p"] = r.TopP } - if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty } - if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty } - if r.Seed != 0 { opts["seed"] = int(r.Seed) } - if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions } + if r.Temperature != nil { + opts["temperature"] = r.Temperature + } + if r.TopP != 0 { + opts["top_p"] = r.TopP + } + if r.FrequencyPenalty != 0 { + opts["frequency_penalty"] = r.FrequencyPenalty + } + if r.PresencePenalty != 0 { + opts["presence_penalty"] = r.PresencePenalty + } + if r.Seed != 0 { + opts["seed"] = int(r.Seed) + } + if r.Dimensions != 0 { + opts["dimensions"] = r.Dimensions + } input := r.ParseInput() - if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} } - return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions} + if len(input) == 1 { + return &OllamaEmbeddingRequest{Model: r.Model, Input: input[0], Options: opts, Dimensions: r.Dimensions} + } + return &OllamaEmbeddingRequest{Model: r.Model, Input: input, Options: opts, Dimensions: r.Dimensions} } func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { var oResp OllamaEmbeddingResponse body, err := io.ReadAll(resp.Body) - if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } service.CloseResponseBodyGracefully(resp) - if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings)) - for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) } - usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount} - embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage} + if err = common.Unmarshal(body, &oResp); err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + if oResp.Error != "" { + return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + data := make([]dto.OpenAIEmbeddingResponseItem, 0, len(oResp.Embeddings)) + for i, emb := range oResp.Embeddings { + data = append(data, dto.OpenAIEmbeddingResponseItem{Index: i, Object: "embedding", Embedding: emb}) + } + usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens: 0, TotalTokens: oResp.PromptEvalCount} + embResp := &dto.OpenAIEmbeddingResponse{Object: "list", Data: data, Model: info.UpstreamModelName, Usage: *usage} out, _ := common.Marshal(embResp) service.IOCopyBytesGracefully(c, resp, out) return usage, nil } - diff --git a/relay/channel/ollama/stream.go b/relay/channel/ollama/stream.go index 964f11d90..9a98f7b69 100644 --- a/relay/channel/ollama/stream.go +++ b/relay/channel/ollama/stream.go @@ -1,210 +1,278 @@ package ollama import ( - "bufio" - "encoding/json" - "fmt" - "io" - "net/http" - "one-api/common" - "one-api/dto" - "one-api/logger" - relaycommon "one-api/relay/common" - "one-api/relay/helper" - "one-api/service" - "one-api/types" - "strings" - "time" + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/logger" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/types" + "strings" + "time" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) type ollamaChatStreamChunk struct { - Model string `json:"model"` - CreatedAt string `json:"created_at"` - // chat - Message *struct { - Role string `json:"role"` - Content string `json:"content"` - Thinking json.RawMessage `json:"thinking"` - ToolCalls []struct { - Function struct { - Name string `json:"name"` - Arguments interface{} `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - // generate - Response string `json:"response"` - Done bool `json:"done"` - DoneReason string `json:"done_reason"` - TotalDuration int64 `json:"total_duration"` - LoadDuration int64 `json:"load_duration"` - PromptEvalCount int `json:"prompt_eval_count"` - EvalCount int `json:"eval_count"` - PromptEvalDuration int64 `json:"prompt_eval_duration"` - EvalDuration int64 `json:"eval_duration"` + Model string `json:"model"` + CreatedAt string `json:"created_at"` + // chat + Message *struct { + Role string `json:"role"` + Content string `json:"content"` + Thinking json.RawMessage `json:"thinking"` + ToolCalls []struct { + Function struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + // generate + Response string `json:"response"` + Done bool `json:"done"` + DoneReason string `json:"done_reason"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalDuration int64 `json:"eval_duration"` } func toUnix(ts string) int64 { - if ts == "" { return time.Now().Unix() } - // try time.RFC3339 or with nanoseconds - t, err := time.Parse(time.RFC3339Nano, ts) - if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() } - return t.Unix() + if ts == "" { + return time.Now().Unix() + } + // try time.RFC3339 or with nanoseconds + t, err := time.Parse(time.RFC3339Nano, ts) + if err != nil { + t2, err2 := time.Parse(time.RFC3339, ts) + if err2 == nil { + return t2.Unix() + } + return time.Now().Unix() + } + return t.Unix() } func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) } - defer service.CloseResponseBodyGracefully(resp) + if resp == nil || resp.Body == nil { + return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) + } + defer service.CloseResponseBodyGracefully(resp) - helper.SetEventStreamHeaders(c) - scanner := bufio.NewScanner(resp.Body) - usage := &dto.Usage{} - var model = info.UpstreamModelName - var responseId = common.GetUUID() - var created = time.Now().Unix() - var toolCallIndex int - start := helper.GenerateStartEmptyResponse(responseId, created, model, nil) - if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) } + helper.SetEventStreamHeaders(c) + scanner := bufio.NewScanner(resp.Body) + usage := &dto.Usage{} + var model = info.UpstreamModelName + var responseId = common.GetUUID() + var created = time.Now().Unix() + var toolCallIndex int + start := helper.GenerateStartEmptyResponse(responseId, created, model, nil) + if data, err := common.Marshal(start); err == nil { + _ = helper.StringData(c, string(data)) + } - for scanner.Scan() { - line := scanner.Text() - line = strings.TrimSpace(line) - if line == "" { continue } - var chunk ollamaChatStreamChunk - if err := json.Unmarshal([]byte(line), &chunk); err != nil { - logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line) - return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - if chunk.Model != "" { model = chunk.Model } - created = toUnix(chunk.CreatedAt) + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + if line == "" { + continue + } + var chunk ollamaChatStreamChunk + if err := json.Unmarshal([]byte(line), &chunk); err != nil { + logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line) + return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + if chunk.Model != "" { + model = chunk.Model + } + created = toUnix(chunk.CreatedAt) - if !chunk.Done { - // delta content - var content string - if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response } - delta := dto.ChatCompletionsStreamResponse{ - Id: responseId, - Object: "chat.completion.chunk", - Created: created, - Model: model, - Choices: []dto.ChatCompletionsStreamResponseChoice{ { - Index: 0, - Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" }, - } }, - } - if content != "" { delta.Choices[0].Delta.SetContentString(content) } - if chunk.Message != nil && len(chunk.Message.Thinking) > 0 { - raw := strings.TrimSpace(string(chunk.Message.Thinking)) - if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) } - } - // tool calls - if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 { - delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls)) - for _, tc := range chunk.Message.ToolCalls { - // arguments -> string - argBytes, _ := json.Marshal(tc.Function.Arguments) - toolId := fmt.Sprintf("call_%d", toolCallIndex) - tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}} - tr.SetIndex(toolCallIndex) - toolCallIndex++ - delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr) - } - } - if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) } - continue - } - // done frame - // finalize once and break loop - usage.PromptTokens = chunk.PromptEvalCount - usage.CompletionTokens = chunk.EvalCount - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - finishReason := chunk.DoneReason - if finishReason == "" { finishReason = "stop" } - // emit stop delta - if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil { - if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) } - } - // emit usage frame - if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil { - if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) } - } - // send [DONE] - helper.Done(c) - break - } - if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) } - return usage, nil + if !chunk.Done { + // delta content + var content string + if chunk.Message != nil { + content = chunk.Message.Content + } else { + content = chunk.Response + } + delta := dto.ChatCompletionsStreamResponse{ + Id: responseId, + Object: "chat.completion.chunk", + Created: created, + Model: model, + Choices: []dto.ChatCompletionsStreamResponseChoice{{ + Index: 0, + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{Role: "assistant"}, + }}, + } + if content != "" { + delta.Choices[0].Delta.SetContentString(content) + } + if chunk.Message != nil && len(chunk.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(chunk.Message.Thinking)) + if raw != "" && raw != "null" { + delta.Choices[0].Delta.SetReasoningContent(raw) + } + } + // tool calls + if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 { + delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 0, len(chunk.Message.ToolCalls)) + for _, tc := range chunk.Message.ToolCalls { + // arguments -> string + argBytes, _ := json.Marshal(tc.Function.Arguments) + toolId := fmt.Sprintf("call_%d", toolCallIndex) + tr := dto.ToolCallResponse{ID: toolId, Type: "function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}} + tr.SetIndex(toolCallIndex) + toolCallIndex++ + delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr) + } + } + if data, err := common.Marshal(delta); err == nil { + _ = helper.StringData(c, string(data)) + } + continue + } + // done frame + // finalize once and break loop + usage.PromptTokens = chunk.PromptEvalCount + usage.CompletionTokens = chunk.EvalCount + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + finishReason := chunk.DoneReason + if finishReason == "" { + finishReason = "stop" + } + // emit stop delta + if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil { + if data, err := common.Marshal(stop); err == nil { + _ = helper.StringData(c, string(data)) + } + } + // emit usage frame + if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil { + if data, err := common.Marshal(final); err == nil { + _ = helper.StringData(c, string(data)) + } + } + // send [DONE] + helper.Done(c) + break + } + if err := scanner.Err(); err != nil && err != io.EOF { + logger.LogError(c, "ollama stream scan error: "+err.Error()) + } + return usage, nil } // non-stream handler for chat/generate func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - body, err := io.ReadAll(resp.Body) - if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } - service.CloseResponseBodyGracefully(resp) - raw := string(body) - if common.DebugEnabled { println("ollama non-stream raw resp:", raw) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) + } + service.CloseResponseBodyGracefully(resp) + raw := string(body) + if common.DebugEnabled { + println("ollama non-stream raw resp:", raw) + } - lines := strings.Split(raw, "\n") - var ( - aggContent strings.Builder - reasoningBuilder strings.Builder - lastChunk ollamaChatStreamChunk - parsedAny bool - ) - for _, ln := range lines { - ln = strings.TrimSpace(ln) - if ln == "" { continue } - var ck ollamaChatStreamChunk - if err := json.Unmarshal([]byte(ln), &ck); err != nil { - if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - continue - } - parsedAny = true - lastChunk = ck - if ck.Message != nil && len(ck.Message.Thinking) > 0 { - raw := strings.TrimSpace(string(ck.Message.Thinking)) - if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } - } - if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) } - } + lines := strings.Split(raw, "\n") + var ( + aggContent strings.Builder + reasoningBuilder strings.Builder + lastChunk ollamaChatStreamChunk + parsedAny bool + ) + for _, ln := range lines { + ln = strings.TrimSpace(ln) + if ln == "" { + continue + } + var ck ollamaChatStreamChunk + if err := json.Unmarshal([]byte(ln), &ck); err != nil { + if len(lines) == 1 { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + continue + } + parsedAny = true + lastChunk = ck + if ck.Message != nil && len(ck.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(ck.Message.Thinking)) + if raw != "" && raw != "null" { + reasoningBuilder.WriteString(raw) + } + } + if ck.Message != nil && ck.Message.Content != "" { + aggContent.WriteString(ck.Message.Content) + } else if ck.Response != "" { + aggContent.WriteString(ck.Response) + } + } - if !parsedAny { - var single ollamaChatStreamChunk - if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - lastChunk = single - if single.Message != nil { - if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } } - aggContent.WriteString(single.Message.Content) - } else { aggContent.WriteString(single.Response) } - } + if !parsedAny { + var single ollamaChatStreamChunk + if err := json.Unmarshal(body, &single); err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + lastChunk = single + if single.Message != nil { + if len(single.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(single.Message.Thinking)) + if raw != "" && raw != "null" { + reasoningBuilder.WriteString(raw) + } + } + aggContent.WriteString(single.Message.Content) + } else { + aggContent.WriteString(single.Response) + } + } - model := lastChunk.Model - if model == "" { model = info.UpstreamModelName } - created := toUnix(lastChunk.CreatedAt) - usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount} - content := aggContent.String() - finishReason := lastChunk.DoneReason - if finishReason == "" { finishReason = "stop" } + model := lastChunk.Model + if model == "" { + model = info.UpstreamModelName + } + created := toUnix(lastChunk.CreatedAt) + usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount} + content := aggContent.String() + finishReason := lastChunk.DoneReason + if finishReason == "" { + finishReason = "stop" + } - msg := dto.Message{Role: "assistant", Content: contentPtr(content)} - if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc } - full := dto.OpenAITextResponse{ - Id: common.GetUUID(), - Model: model, - Object: "chat.completion", - Created: created, - Choices: []dto.OpenAITextResponseChoice{ { - Index: 0, - Message: msg, - FinishReason: finishReason, - } }, - Usage: *usage, - } - out, _ := common.Marshal(full) - service.IOCopyBytesGracefully(c, resp, out) - return usage, nil + msg := dto.Message{Role: "assistant", Content: contentPtr(content)} + if rc := reasoningBuilder.String(); rc != "" { + msg.ReasoningContent = rc + } + full := dto.OpenAITextResponse{ + Id: common.GetUUID(), + Model: model, + Object: "chat.completion", + Created: created, + Choices: []dto.OpenAITextResponseChoice{{ + Index: 0, + Message: msg, + FinishReason: finishReason, + }}, + Usage: *usage, + } + out, _ := common.Marshal(full) + service.IOCopyBytesGracefully(c, resp, out) + return usage, nil } -func contentPtr(s string) *string { if s=="" { return nil }; return &s } +func contentPtr(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/relay/channel/submodel/constants.go b/relay/channel/submodel/constants.go index f5e1feb84..72d6fee31 100644 --- a/relay/channel/submodel/constants.go +++ b/relay/channel/submodel/constants.go @@ -13,4 +13,4 @@ var ModelList = []string{ "deepseek-ai/DeepSeek-V3.1", } -const ChannelName = "submodel" \ No newline at end of file +const ChannelName = "submodel" diff --git a/setting/operation_setting/general_setting.go b/setting/operation_setting/general_setting.go index ae0c436ec..c47ff64fa 100644 --- a/setting/operation_setting/general_setting.go +++ b/setting/operation_setting/general_setting.go @@ -2,17 +2,34 @@ package operation_setting import "one-api/setting/config" +// 额度展示类型 +const ( + QuotaDisplayTypeUSD = "USD" + QuotaDisplayTypeCNY = "CNY" + QuotaDisplayTypeTokens = "TOKENS" + QuotaDisplayTypeCustom = "CUSTOM" +) + type GeneralSetting struct { DocsLink string `json:"docs_link"` PingIntervalEnabled bool `json:"ping_interval_enabled"` PingIntervalSeconds int `json:"ping_interval_seconds"` + // 当前站点额度展示类型:USD / CNY / TOKENS + QuotaDisplayType string `json:"quota_display_type"` + // 自定义货币符号,用于 CUSTOM 展示类型 + CustomCurrencySymbol string `json:"custom_currency_symbol"` + // 自定义货币与美元汇率(1 USD = X Custom) + CustomCurrencyExchangeRate float64 `json:"custom_currency_exchange_rate"` } // 默认配置 var generalSetting = GeneralSetting{ - DocsLink: "https://docs.newapi.pro", - PingIntervalEnabled: false, - PingIntervalSeconds: 60, + DocsLink: "https://docs.newapi.pro", + PingIntervalEnabled: false, + PingIntervalSeconds: 60, + QuotaDisplayType: QuotaDisplayTypeUSD, + CustomCurrencySymbol: "¤", + CustomCurrencyExchangeRate: 1.0, } func init() { @@ -23,3 +40,52 @@ func init() { func GetGeneralSetting() *GeneralSetting { return &generalSetting } + +// IsCurrencyDisplay 是否以货币形式展示(美元或人民币) +func IsCurrencyDisplay() bool { + return generalSetting.QuotaDisplayType != QuotaDisplayTypeTokens +} + +// IsCNYDisplay 是否以人民币展示 +func IsCNYDisplay() bool { + return generalSetting.QuotaDisplayType == QuotaDisplayTypeCNY +} + +// GetQuotaDisplayType 返回额度展示类型 +func GetQuotaDisplayType() string { + return generalSetting.QuotaDisplayType +} + +// GetCurrencySymbol 返回当前展示类型对应符号 +func GetCurrencySymbol() string { + switch generalSetting.QuotaDisplayType { + case QuotaDisplayTypeUSD: + return "$" + case QuotaDisplayTypeCNY: + return "¥" + case QuotaDisplayTypeCustom: + if generalSetting.CustomCurrencySymbol != "" { + return generalSetting.CustomCurrencySymbol + } + return "¤" + default: + return "" + } +} + +// GetUsdToCurrencyRate 返回 1 USD = X 的 X(TOKENS 不适用) +func GetUsdToCurrencyRate(usdToCny float64) float64 { + switch generalSetting.QuotaDisplayType { + case QuotaDisplayTypeUSD: + return 1 + case QuotaDisplayTypeCNY: + return usdToCny + case QuotaDisplayTypeCustom: + if generalSetting.CustomCurrencyExchangeRate > 0 { + return generalSetting.CustomCurrencyExchangeRate + } + return 1 + default: + return 1 + } +} diff --git a/web/index.html b/web/index.html index df6b0e398..a9df87f5d 100644 --- a/web/index.html +++ b/web/index.html @@ -10,7 +10,7 @@ content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用" /> New API - + diff --git a/web/src/components/settings/OperationSetting.jsx b/web/src/components/settings/OperationSetting.jsx index 0d6e44107..51cdb1de6 100644 --- a/web/src/components/settings/OperationSetting.jsx +++ b/web/src/components/settings/OperationSetting.jsx @@ -42,7 +42,7 @@ const OperationSetting = () => { QuotaPerUnit: 0, USDExchangeRate: 0, RetryTimes: 0, - DisplayInCurrencyEnabled: false, + 'general_setting.quota_display_type': 'USD', DisplayTokenStatEnabled: false, DefaultCollapseSidebar: false, DemoSiteEnabled: false, diff --git a/web/src/components/settings/personal/cards/AccountManagement.jsx b/web/src/components/settings/personal/cards/AccountManagement.jsx index 93a2daf89..ac2146c27 100644 --- a/web/src/components/settings/personal/cards/AccountManagement.jsx +++ b/web/src/components/settings/personal/cards/AccountManagement.jsx @@ -91,7 +91,8 @@ const AccountManagement = ({ ); }; const isBound = (accountId) => Boolean(accountId); - const [showTelegramBindModal, setShowTelegramBindModal] = React.useState(false); + const [showTelegramBindModal, setShowTelegramBindModal] = + React.useState(false); const passkeyEnabled = passkeyStatus?.enabled; const lastUsedLabel = passkeyStatus?.last_used_at ? new Date(passkeyStatus.last_used_at).toLocaleString() @@ -236,7 +237,8 @@ const AccountManagement = ({ onGitHubOAuthClicked(status.github_client_id) } disabled={ - isBound(userState.user?.github_id) || !status.github_oauth + isBound(userState.user?.github_id) || + !status.github_oauth } > {status.github_oauth ? t('绑定') : t('未启用')} @@ -394,7 +396,8 @@ const AccountManagement = ({ onLinuxDOOAuthClicked(status.linuxdo_client_id) } disabled={ - isBound(userState.user?.linux_do_id) || !status.linuxdo_oauth + isBound(userState.user?.linux_do_id) || + !status.linuxdo_oauth } > {status.linuxdo_oauth ? t('绑定') : t('未启用')} diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index dfbd75a43..d5d299969 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -91,8 +91,7 @@ const REGION_EXAMPLE = { // 支持并且已适配通过接口获取模型列表的渠道类型 const MODEL_FETCHABLE_TYPES = new Set([ - 1, 4, 14, 34, 17, 26, 24, 47, 25, 20, 23, 31, 35, 40, 42, 48, - 43, + 1, 4, 14, 34, 17, 26, 24, 47, 25, 20, 23, 31, 35, 40, 42, 48, 43, ]); function type2secretPrompt(type) { @@ -408,7 +407,10 @@ const EditChannelModal = (props) => { break; case 45: localModels = getChannelModels(value); - setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' })); + setInputs((prevInputs) => ({ + ...prevInputs, + base_url: 'https://ark.cn-beijing.volces.com', + })); break; default: localModels = getChannelModels(value); @@ -502,7 +504,8 @@ const EditChannelModal = (props) => { // 读取 Vertex 密钥格式 data.vertex_key_type = parsedSettings.vertex_key_type || 'json'; // 读取企业账户设置 - data.is_enterprise_account = parsedSettings.openrouter_enterprise === true; + data.is_enterprise_account = + parsedSettings.openrouter_enterprise === true; // 读取字段透传控制设置 data.allow_service_tier = parsedSettings.allow_service_tier || false; data.disable_store = parsedSettings.disable_store || false; @@ -929,7 +932,10 @@ const EditChannelModal = (props) => { showInfo(t('请至少选择一个模型!')); return; } - if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) { + if ( + localInputs.type === 45 && + (!localInputs.base_url || localInputs.base_url.trim() === '') + ) { showInfo(t('请输入API地址!')); return; } @@ -974,7 +980,8 @@ const EditChannelModal = (props) => { // type === 20: 设置企业账户标识,无论是true还是false都要传到后端 if (localInputs.type === 20) { - settings.openrouter_enterprise = localInputs.is_enterprise_account === true; + settings.openrouter_enterprise = + localInputs.is_enterprise_account === true; } // type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值) @@ -1433,7 +1440,9 @@ const EditChannelModal = (props) => { setIsEnterpriseAccount(value); handleInputChange('is_enterprise_account', value); }} - extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')} + extraText={t( + '企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选', + )} initValue={inputs.is_enterprise_account} /> )} @@ -2061,27 +2070,27 @@ const EditChannelModal = (props) => { )} {inputs.type === 45 && ( -
- - handleInputChange('base_url', value) - } - optionList={[ - { - value: 'https://ark.cn-beijing.volces.com', - label: 'https://ark.cn-beijing.volces.com' - }, - { - value: 'https://ark.ap-southeast.bytepluses.com', - label: 'https://ark.ap-southeast.bytepluses.com' - } - ]} - defaultValue='https://ark.cn-beijing.volces.com' - /> -
+
+ + handleInputChange('base_url', value) + } + optionList={[ + { + value: 'https://ark.cn-beijing.volces.com', + label: 'https://ark.cn-beijing.volces.com', + }, + { + value: 'https://ark.ap-southeast.bytepluses.com', + label: 'https://ark.ap-southeast.bytepluses.com', + }, + ]} + defaultValue='https://ark.cn-beijing.volces.com' + /> +
)} diff --git a/web/src/components/table/model-pricing/filter/PricingDisplaySettings.jsx b/web/src/components/table/model-pricing/filter/PricingDisplaySettings.jsx index 4423ce39c..71dbd2000 100644 --- a/web/src/components/table/model-pricing/filter/PricingDisplaySettings.jsx +++ b/web/src/components/table/model-pricing/filter/PricingDisplaySettings.jsx @@ -56,6 +56,7 @@ const PricingDisplaySettings = ({ const currencyItems = [ { value: 'USD', label: 'USD ($)' }, { value: 'CNY', label: 'CNY (¥)' }, + { value: 'CUSTOM', label: t('自定义货币') }, ]; const handleChange = (value) => { diff --git a/web/src/components/table/model-pricing/layout/header/SearchActions.jsx b/web/src/components/table/model-pricing/layout/header/SearchActions.jsx index 646ffe023..c961b8dc1 100644 --- a/web/src/components/table/model-pricing/layout/header/SearchActions.jsx +++ b/web/src/components/table/model-pricing/layout/header/SearchActions.jsx @@ -107,6 +107,7 @@ const SearchActions = memo( optionList={[ { value: 'USD', label: 'USD' }, { value: 'CNY', label: 'CNY' }, + { value: 'CUSTOM', label: t('自定义货币') }, ]} /> )} diff --git a/web/src/components/table/task-logs/modals/ContentModal.jsx b/web/src/components/table/task-logs/modals/ContentModal.jsx index 3bfba37b1..1c7d96641 100644 --- a/web/src/components/table/task-logs/modals/ContentModal.jsx +++ b/web/src/components/table/task-logs/modals/ContentModal.jsx @@ -60,38 +60,54 @@ const ContentModal = ({ if (videoError) { return (
- + 视频无法在当前浏览器中播放,这可能是由于: - + • 视频服务商的跨域限制 - + • 需要特定的请求头或认证 - + • 防盗链保护机制 - +
- -
- -
- + {modalContent} @@ -104,22 +120,24 @@ const ContentModal = ({ return (
{isLoading && ( -
- +
+
)} -