mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 17:44:41 +00:00
Compare commits
30 Commits
v0.10.8
...
v0.10.9-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c0db08f32 | ||
|
|
11de49f9b9 | ||
|
|
4950db666f | ||
|
|
44c5fac5ea | ||
|
|
7a146a11f5 | ||
|
|
897955256e | ||
|
|
bc6810ca5a | ||
|
|
742f4ad1e4 | ||
|
|
83a5245bb1 | ||
|
|
2faa873caf | ||
|
|
ce0113a6b5 | ||
|
|
dd5610d39e | ||
|
|
8e1a990b45 | ||
|
|
5f6f95c7c1 | ||
|
|
0b3a0b38d6 | ||
|
|
bbad917101 | ||
|
|
a0bb78edd0 | ||
|
|
fac9c367b1 | ||
|
|
23227e18f9 | ||
|
|
4332837f05 | ||
|
|
50ee4361d0 | ||
|
|
3af53bdd41 | ||
|
|
aa8240e482 | ||
|
|
b580b8bd1d | ||
|
|
e8d26e52d8 | ||
|
|
2567cff6c8 | ||
|
|
7314c974f3 | ||
|
|
fca80a57ad | ||
|
|
3229b81149 | ||
|
|
5efb402532 |
@@ -2,29 +2,37 @@ package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var TopupGroupRatio = map[string]float64{
|
||||
var topupGroupRatio = map[string]float64{
|
||||
"default": 1,
|
||||
"vip": 1,
|
||||
"svip": 1,
|
||||
}
|
||||
var topupGroupRatioMutex sync.RWMutex
|
||||
|
||||
func TopupGroupRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(TopupGroupRatio)
|
||||
topupGroupRatioMutex.RLock()
|
||||
defer topupGroupRatioMutex.RUnlock()
|
||||
jsonBytes, err := json.Marshal(topupGroupRatio)
|
||||
if err != nil {
|
||||
SysError("error marshalling model ratio: " + err.Error())
|
||||
SysError("error marshalling topup group ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
|
||||
TopupGroupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
|
||||
topupGroupRatioMutex.Lock()
|
||||
defer topupGroupRatioMutex.Unlock()
|
||||
topupGroupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &topupGroupRatio)
|
||||
}
|
||||
|
||||
func GetTopupGroupRatio(name string) float64 {
|
||||
ratio, ok := TopupGroupRatio[name]
|
||||
topupGroupRatioMutex.RLock()
|
||||
defer topupGroupRatioMutex.RUnlock()
|
||||
ratio, ok := topupGroupRatio[name]
|
||||
if !ok {
|
||||
SysError("topup group ratio not found: " + name)
|
||||
return 1
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/samber/lo"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -41,7 +42,21 @@ type testResult struct {
|
||||
newAPIError *types.NewAPIError
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
|
||||
func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string {
|
||||
normalized := strings.TrimSpace(endpointType)
|
||||
if normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) {
|
||||
return string(constant.EndpointTypeOpenAIResponseCompact)
|
||||
}
|
||||
if channel != nil && channel.Type == constant.ChannelTypeCodex {
|
||||
return string(constant.EndpointTypeOpenAIResponse)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
|
||||
tik := time.Now()
|
||||
var unsupportedTestChannelTypes = []int{
|
||||
constant.ChannelTypeMidjourney,
|
||||
@@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)
|
||||
|
||||
requestPath := "/v1/chat/completions"
|
||||
|
||||
// 如果指定了端点类型,使用指定的端点类型
|
||||
@@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel, endpointType, channel)
|
||||
request := buildTestRequest(testModel, endpointType, channel, isStream)
|
||||
|
||||
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||
|
||||
@@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
newAPIError: respErr,
|
||||
}
|
||||
}
|
||||
if usageA == nil {
|
||||
usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
|
||||
if usageErr != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
localErr: usageErr,
|
||||
newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
respBody, err := readTestResponseBody(result.Body, isStream)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
@@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: bodyErr,
|
||||
newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.SetEstimatePromptTokens(usage.PromptTokens)
|
||||
|
||||
quota := 0
|
||||
@@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
|
||||
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
|
||||
switch u := usageAny.(type) {
|
||||
case *dto.Usage:
|
||||
return u, nil
|
||||
case dto.Usage:
|
||||
return &u, nil
|
||||
case nil:
|
||||
if !isStream {
|
||||
return nil, errors.New("usage is nil")
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: estimatePromptTokens,
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens
|
||||
return usage, nil
|
||||
default:
|
||||
if !isStream {
|
||||
return nil, fmt.Errorf("invalid usage type: %T", usageAny)
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: estimatePromptTokens,
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
|
||||
defer func() { _ = body.Close() }()
|
||||
const maxStreamLogBytes = 8 << 10
|
||||
if isStream {
|
||||
return io.ReadAll(io.LimitReader(body, maxStreamLogBytes))
|
||||
}
|
||||
return io.ReadAll(body)
|
||||
}
|
||||
|
||||
func detectErrorFromTestResponseBody(respBody []byte) error {
|
||||
b := bytes.TrimSpace(respBody)
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
if message := detectErrorMessageFromJSONBytes(b); message != "" {
|
||||
return fmt.Errorf("upstream error: %s", message)
|
||||
}
|
||||
|
||||
for _, line := range bytes.Split(b, []byte{'\n'}) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
||||
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if message := detectErrorMessageFromJSONBytes(payload); message != "" {
|
||||
return fmt.Errorf("upstream error: %s", message)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
|
||||
if len(jsonBytes) == 0 {
|
||||
return ""
|
||||
}
|
||||
if jsonBytes[0] != '{' && jsonBytes[0] != '[' {
|
||||
return ""
|
||||
}
|
||||
errVal := gjson.GetBytes(jsonBytes, "error")
|
||||
if !errVal.Exists() || errVal.Type == gjson.Null {
|
||||
return ""
|
||||
}
|
||||
|
||||
message := gjson.GetBytes(jsonBytes, "error.message").String()
|
||||
if message == "" {
|
||||
message = gjson.GetBytes(jsonBytes, "error.error.message").String()
|
||||
}
|
||||
if message == "" && errVal.Type == gjson.String {
|
||||
message = errVal.String()
|
||||
}
|
||||
if message == "" {
|
||||
message = errVal.Raw
|
||||
}
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return "upstream returned error payload"
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
|
||||
testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
|
||||
|
||||
// 根据端点类型构建不同的测试请求
|
||||
@@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
||||
case constant.EndpointTypeOpenAIResponse:
|
||||
// 返回 OpenAIResponsesRequest
|
||||
return &dto.OpenAIResponsesRequest{
|
||||
Model: model,
|
||||
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
||||
Model: model,
|
||||
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
||||
Stream: isStream,
|
||||
}
|
||||
case constant.EndpointTypeOpenAIResponseCompact:
|
||||
// 返回 OpenAIResponsesCompactionRequest
|
||||
@@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
||||
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
|
||||
maxTokens = 3000
|
||||
}
|
||||
return &dto.GeneralOpenAIRequest{
|
||||
req := &dto.GeneralOpenAIRequest{
|
||||
Model: model,
|
||||
Stream: false,
|
||||
Stream: isStream,
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
@@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
||||
},
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
if isStream {
|
||||
req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
||||
}
|
||||
return req
|
||||
}
|
||||
}
|
||||
|
||||
@@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
||||
// Responses-only models (e.g. codex series)
|
||||
if strings.Contains(strings.ToLower(model), "codex") {
|
||||
return &dto.OpenAIResponsesRequest{
|
||||
Model: model,
|
||||
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
||||
Model: model,
|
||||
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
||||
Stream: isStream,
|
||||
}
|
||||
}
|
||||
|
||||
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: model,
|
||||
Stream: false,
|
||||
Stream: isStream,
|
||||
Messages: []dto.Message{
|
||||
{
|
||||
Role: "user",
|
||||
@@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
|
||||
},
|
||||
},
|
||||
}
|
||||
if isStream {
|
||||
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "o") {
|
||||
testRequest.MaxCompletionTokens = 16
|
||||
@@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) {
|
||||
//}()
|
||||
testModel := c.Query("model")
|
||||
endpointType := c.Query("endpoint_type")
|
||||
isStream, _ := strconv.ParseBool(c.Query("stream"))
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, testModel, endpointType)
|
||||
result := testChannel(channel, testModel, endpointType, isStream)
|
||||
if result.localErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -678,7 +806,7 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, "", "")
|
||||
result := testChannel(channel, "", "", false)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
|
||||
@@ -166,21 +166,21 @@ func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
|
||||
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
|
||||
type UpdateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
|
||||
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
||||
@@ -227,7 +227,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
if req.Slug != "" {
|
||||
provider.Slug = req.Slug
|
||||
}
|
||||
provider.Enabled = req.Enabled
|
||||
if req.Enabled != nil {
|
||||
provider.Enabled = *req.Enabled
|
||||
}
|
||||
if req.ClientId != "" {
|
||||
provider.ClientId = req.ClientId
|
||||
}
|
||||
@@ -258,8 +260,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
if req.EmailField != "" {
|
||||
provider.EmailField = req.EmailField
|
||||
}
|
||||
provider.WellKnown = req.WellKnown
|
||||
provider.AuthStyle = req.AuthStyle
|
||||
if req.WellKnown != nil {
|
||||
provider.WellKnown = *req.WellKnown
|
||||
}
|
||||
if req.AuthStyle != nil {
|
||||
provider.AuthStyle = *req.AuthStyle
|
||||
}
|
||||
|
||||
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -296,7 +302,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Check if there are any user bindings
|
||||
count, _ := model.GetBindingCountByProviderId(id)
|
||||
count, err := model.GetBindingCountByProviderId(id)
|
||||
if err != nil {
|
||||
common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
|
||||
common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
|
||||
return
|
||||
}
|
||||
if count > 0 {
|
||||
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
|
||||
return
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// providerParams returns map with Provider key for i18n templates
|
||||
@@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For custom providers, create the binding after user is created
|
||||
// Use transaction to ensure user creation and OAuth binding are atomic
|
||||
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
||||
binding := &model.UserOAuthBinding{
|
||||
UserId: user.Id,
|
||||
ProviderId: genericProvider.GetProviderId(),
|
||||
ProviderUserId: oauthUser.ProviderUserID,
|
||||
}
|
||||
if err := model.CreateUserOAuthBinding(binding); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
|
||||
// Don't fail the registration, just log the error
|
||||
// Custom provider: create user and binding in a transaction
|
||||
err := model.DB.Transaction(func(tx *gorm.DB) error {
|
||||
// Create user
|
||||
if err := user.InsertWithTx(tx, inviterId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create OAuth binding
|
||||
binding := &model.UserOAuthBinding{
|
||||
UserId: user.Id,
|
||||
ProviderId: genericProvider.GetProviderId(),
|
||||
ProviderUserId: oauthUser.ProviderUserID,
|
||||
}
|
||||
if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Perform post-transaction tasks (logs, sidebar config, inviter rewards)
|
||||
user.FinalizeOAuthUserCreation(inviterId)
|
||||
} else {
|
||||
// Built-in provider: set the provider user ID on the user model
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
if err := user.Update(false); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
|
||||
// Built-in provider: create user and update provider ID in a transaction
|
||||
err := model.DB.Transaction(func(tx *gorm.DB) error {
|
||||
// Create user
|
||||
if err := user.InsertWithTx(tx, inviterId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the provider user ID on the user model and update
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
if err := tx.Model(user).Updates(map[string]interface{}{
|
||||
"github_id": user.GitHubId,
|
||||
"discord_id": user.DiscordId,
|
||||
"oidc_id": user.OidcId,
|
||||
"linux_do_id": user.LinuxDOId,
|
||||
"wechat_id": user.WeChatId,
|
||||
"telegram_id": user.TelegramId,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Perform post-transaction tasks
|
||||
user.FinalizeOAuthUserCreation(inviterId)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
@@ -169,6 +169,15 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "CreateCacheRatio":
|
||||
err = ratio_setting.UpdateCreateCacheRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "缓存创建倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||
if err != nil {
|
||||
|
||||
@@ -27,6 +27,7 @@ type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
|
||||
ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
|
||||
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
|
||||
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
|
||||
|
||||
@@ -97,13 +97,18 @@ func DeleteCustomOAuthProvider(id int) error {
|
||||
}
|
||||
|
||||
// IsSlugTaken checks if a slug is already taken by another provider
|
||||
// Returns true on DB errors (fail-closed) to prevent slug conflicts
|
||||
func IsSlugTaken(slug string, excludeId int) bool {
|
||||
var count int64
|
||||
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
|
||||
if excludeId > 0 {
|
||||
query = query.Where("id != ?", excludeId)
|
||||
}
|
||||
query.Count(&count)
|
||||
res := query.Count(&count)
|
||||
if res.Error != nil {
|
||||
// Fail-closed: treat DB errors as slug being taken to prevent conflicts
|
||||
return true
|
||||
}
|
||||
return count > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,21 @@ func (mi *Model) Insert() error {
|
||||
now := common.GetTimestamp()
|
||||
mi.CreatedTime = now
|
||||
mi.UpdatedTime = now
|
||||
return DB.Create(mi).Error
|
||||
|
||||
// 保存原始值(因为 Create 后可能被 GORM 的 default 标签覆盖为 1)
|
||||
originalStatus := mi.Status
|
||||
originalSyncOfficial := mi.SyncOfficial
|
||||
|
||||
// 先创建记录(GORM 会对零值字段应用默认值)
|
||||
if err := DB.Create(mi).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用保存的原始值进行更新,确保零值能正确保存
|
||||
return DB.Model(&Model{}).Where("id = ?", mi.Id).Updates(map[string]interface{}{
|
||||
"status": originalStatus,
|
||||
"sync_official": originalSyncOfficial,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||
@@ -61,11 +75,9 @@ func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||||
|
||||
func (mi *Model) Update() error {
|
||||
mi.UpdatedTime = common.GetTimestamp()
|
||||
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
|
||||
Model(&Model{}).
|
||||
Where("id = ?", mi.Id).
|
||||
Omit("created_time").
|
||||
Select("*").
|
||||
// 使用 Select 强制更新所有字段,包括零值
|
||||
return DB.Model(&Model{}).Where("id = ?", mi.Id).
|
||||
Select("model_name", "description", "icon", "tags", "vendor_id", "endpoints", "status", "sync_official", "name_rule", "updated_time").
|
||||
Updates(mi).Error
|
||||
}
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["CreateCacheRatio"] = ratio_setting.CreateCacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
@@ -427,6 +428,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "CreateCacheRatio":
|
||||
err = ratio_setting.UpdateCreateCacheRatioByJSONString(value)
|
||||
case "ImageRatio":
|
||||
err = ratio_setting.UpdateImageRatioByJSONString(value)
|
||||
case "AudioRatio":
|
||||
|
||||
@@ -196,20 +196,25 @@ func updatePricing() {
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
|
||||
// 再补充模型自定义端点
|
||||
// 再补充模型自定义端点:若配置有效则替换默认端点,不做合并
|
||||
for modelName, meta := range metaMap {
|
||||
if strings.TrimSpace(meta.Endpoints) == "" {
|
||||
continue
|
||||
}
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
|
||||
endpoints := modelSupportEndpointsStr[modelName]
|
||||
for k := range raw {
|
||||
if !common.StringsContains(endpoints, k) {
|
||||
endpoints = append(endpoints, k)
|
||||
endpoints := make([]string, 0, len(raw))
|
||||
for k, v := range raw {
|
||||
switch v.(type) {
|
||||
case string, map[string]interface{}:
|
||||
if !common.StringsContains(endpoints, k) {
|
||||
endpoints = append(endpoints, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
modelSupportEndpointsStr[modelName] = endpoints
|
||||
if len(endpoints) > 0 {
|
||||
modelSupportEndpointsStr[modelName] = endpoints
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -429,6 +429,65 @@ func (user *User) Insert(inviterId int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InsertWithTx inserts a new user within an existing transaction.
|
||||
// This is used for OAuth registration where user creation and binding need to be atomic.
|
||||
// Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
|
||||
func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
|
||||
var err error
|
||||
if user.Password != "" {
|
||||
user.Password, err = common.Password2Hash(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
user.Quota = common.QuotaForNewUser
|
||||
user.AffCode = common.GetRandomString(4)
|
||||
|
||||
// 初始化用户设置
|
||||
if user.Setting == "" {
|
||||
defaultSetting := dto.UserSetting{}
|
||||
user.SetSetting(defaultSetting)
|
||||
}
|
||||
|
||||
result := tx.Create(user)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
|
||||
// This should be called after the transaction commits successfully.
|
||||
func (user *User) FinalizeOAuthUserCreation(inviterId int) {
|
||||
// 用户创建成功后,根据角色初始化边栏配置
|
||||
var createdUser User
|
||||
if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
|
||||
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
|
||||
if defaultSidebarConfig != "" {
|
||||
currentSetting := createdUser.GetSetting()
|
||||
currentSetting.SidebarModules = defaultSidebarConfig
|
||||
createdUser.SetSetting(currentSetting)
|
||||
createdUser.Update(false)
|
||||
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
|
||||
}
|
||||
}
|
||||
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
if inviterId != 0 {
|
||||
if common.QuotaForInvitee > 0 {
|
||||
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
|
||||
}
|
||||
if common.QuotaForInviter > 0 {
|
||||
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
|
||||
_ = inviteUser(inviterId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) Update(updatePassword bool) error {
|
||||
var err error
|
||||
if updatePassword {
|
||||
|
||||
@@ -3,18 +3,17 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
|
||||
type UserOAuthBinding struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"index;not null"` // User ID
|
||||
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
|
||||
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
|
||||
UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
|
||||
ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
|
||||
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Composite unique index to prevent duplicate bindings
|
||||
// One OAuth account can only be bound to one user
|
||||
}
|
||||
|
||||
func (UserOAuthBinding) TableName() string {
|
||||
@@ -82,6 +81,29 @@ func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
|
||||
return DB.Create(binding).Error
|
||||
}
|
||||
|
||||
// CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction
|
||||
func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
|
||||
if binding.UserId == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
if binding.ProviderId == 0 {
|
||||
return errors.New("provider ID is required")
|
||||
}
|
||||
if binding.ProviderUserId == "" {
|
||||
return errors.New("provider user ID is required")
|
||||
}
|
||||
|
||||
// Check if this provider user ID is already taken (use tx to check within the same transaction)
|
||||
var count int64
|
||||
tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count)
|
||||
if count > 0 {
|
||||
return errors.New("this OAuth account is already bound to another user")
|
||||
}
|
||||
|
||||
binding.CreatedAt = time.Now()
|
||||
return tx.Create(binding).Error
|
||||
}
|
||||
|
||||
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
|
||||
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
|
||||
// Check if the new provider user ID is already taken by another user
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -122,6 +123,17 @@ func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*O
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
// Check for non-200 status codes before attempting to decode
|
||||
if res.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
bodyStr := string(body)
|
||||
if len(bodyStr) > 500 {
|
||||
bodyStr = bodyStr[:500] + "..."
|
||||
}
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo failed: status=%d, body=%s", res.StatusCode, bodyStr))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, map[string]any{"Provider": "GitHub"}, fmt.Sprintf("status %d", res.StatusCode))
|
||||
}
|
||||
|
||||
var githubUser gitHubUser
|
||||
err = json.NewDecoder(res.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
|
||||
@@ -223,11 +223,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if supportsAliAnthropicMessages(info.UpstreamModelName) {
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
}
|
||||
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
|
||||
@@ -95,6 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
info.FinalRequestRelayFormat = types.RelayFormatClaude
|
||||
if info.IsStream {
|
||||
return ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
111
relay/channel/claude/message_delta_usage_patch_test.go
Normal file
111
relay/channel/claude/message_delta_usage_patch_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestPatchClaudeMessageDeltaUsageDataPreserveUnknownFields(t *testing.T) {
|
||||
originalData := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":53},"vendor_meta":{"trace_id":"trace_001"}}`
|
||||
usage := &dto.ClaudeUsage{
|
||||
InputTokens: 100,
|
||||
CacheReadInputTokens: 30,
|
||||
CacheCreationInputTokens: 50,
|
||||
}
|
||||
|
||||
patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
|
||||
|
||||
require.Equal(t, "message_delta", gjson.Get(patchedData, "type").String())
|
||||
require.Equal(t, "end_turn", gjson.Get(patchedData, "delta.stop_reason").String())
|
||||
require.Equal(t, "trace_001", gjson.Get(patchedData, "vendor_meta.trace_id").String())
|
||||
require.EqualValues(t, 53, gjson.Get(patchedData, "usage.output_tokens").Int())
|
||||
require.EqualValues(t, 100, gjson.Get(patchedData, "usage.input_tokens").Int())
|
||||
require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
|
||||
require.EqualValues(t, 50, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestPatchClaudeMessageDeltaUsageDataZeroValueChecks(t *testing.T) {
|
||||
originalData := `{"type":"message_delta","usage":{"output_tokens":53,"input_tokens":9,"cache_read_input_tokens":0}}`
|
||||
usage := &dto.ClaudeUsage{
|
||||
InputTokens: 100,
|
||||
CacheReadInputTokens: 30,
|
||||
CacheCreationInputTokens: 0,
|
||||
}
|
||||
|
||||
patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
|
||||
|
||||
require.EqualValues(t, 9, gjson.Get(patchedData, "usage.input_tokens").Int())
|
||||
require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
|
||||
assert.False(t, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Exists())
|
||||
}
|
||||
|
||||
func TestShouldSkipClaudeMessageDeltaUsagePatch(t *testing.T) {
|
||||
originGlobalPassThrough := model_setting.GetGlobalSettings().PassThroughRequestEnabled
|
||||
t.Cleanup(func() {
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = originGlobalPassThrough
|
||||
})
|
||||
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
|
||||
assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{}))
|
||||
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = false
|
||||
assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: true}},
|
||||
}))
|
||||
assert.False(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: false}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestBuildMessageDeltaPatchUsage(t *testing.T) {
|
||||
t.Run("merge missing fields from claudeInfo", func(t *testing.T) {
|
||||
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{OutputTokens: 53}}
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
},
|
||||
}
|
||||
|
||||
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
||||
require.NotNil(t, usage)
|
||||
require.EqualValues(t, 100, usage.InputTokens)
|
||||
require.EqualValues(t, 30, usage.CacheReadInputTokens)
|
||||
require.EqualValues(t, 50, usage.CacheCreationInputTokens)
|
||||
require.EqualValues(t, 53, usage.OutputTokens)
|
||||
require.NotNil(t, usage.CacheCreation)
|
||||
require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
|
||||
require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
|
||||
})
|
||||
|
||||
t.Run("keep upstream non-zero values", func(t *testing.T) {
|
||||
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
|
||||
InputTokens: 9,
|
||||
CacheReadInputTokens: 7,
|
||||
CacheCreationInputTokens: 6,
|
||||
}}
|
||||
claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
}}
|
||||
|
||||
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
|
||||
require.EqualValues(t, 9, usage.InputTokens)
|
||||
require.EqualValues(t, 7, usage.CacheReadInputTokens)
|
||||
require.EqualValues(t, 6, usage.CacheCreationInputTokens)
|
||||
})
|
||||
}
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -544,6 +546,78 @@ type ClaudeResponseInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
|
||||
usage := &dto.ClaudeUsage{}
|
||||
if claudeResponse != nil && claudeResponse.Usage != nil {
|
||||
*usage = *claudeResponse.Usage
|
||||
}
|
||||
|
||||
if claudeInfo == nil || claudeInfo.Usage == nil {
|
||||
return usage
|
||||
}
|
||||
|
||||
if usage.InputTokens == 0 && claudeInfo.Usage.PromptTokens > 0 {
|
||||
usage.InputTokens = claudeInfo.Usage.PromptTokens
|
||||
}
|
||||
if usage.CacheReadInputTokens == 0 && claudeInfo.Usage.PromptTokensDetails.CachedTokens > 0 {
|
||||
usage.CacheReadInputTokens = claudeInfo.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 && claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens > 0 {
|
||||
usage.CacheCreationInputTokens = claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens
|
||||
}
|
||||
if usage.CacheCreation == nil && (claudeInfo.Usage.ClaudeCacheCreation5mTokens > 0 || claudeInfo.Usage.ClaudeCacheCreation1hTokens > 0) {
|
||||
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{
|
||||
Ephemeral5mInputTokens: claudeInfo.Usage.ClaudeCacheCreation5mTokens,
|
||||
Ephemeral1hInputTokens: claudeInfo.Usage.ClaudeCacheCreation1hTokens,
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func shouldSkipClaudeMessageDeltaUsagePatch(info *relaycommon.RelayInfo) bool {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
return true
|
||||
}
|
||||
if info == nil {
|
||||
return false
|
||||
}
|
||||
return info.ChannelSetting.PassThroughBodyEnabled
|
||||
}
|
||||
|
||||
func patchClaudeMessageDeltaUsageData(data string, usage *dto.ClaudeUsage) string {
|
||||
if data == "" || usage == nil {
|
||||
return data
|
||||
}
|
||||
|
||||
data = setMessageDeltaUsageInt(data, "usage.input_tokens", usage.InputTokens)
|
||||
data = setMessageDeltaUsageInt(data, "usage.cache_read_input_tokens", usage.CacheReadInputTokens)
|
||||
data = setMessageDeltaUsageInt(data, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
|
||||
|
||||
if usage.CacheCreation != nil {
|
||||
data = setMessageDeltaUsageInt(data, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation.Ephemeral5mInputTokens)
|
||||
data = setMessageDeltaUsageInt(data, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation.Ephemeral1hInputTokens)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func setMessageDeltaUsageInt(data string, path string, localValue int) string {
|
||||
if localValue <= 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
upstreamValue := gjson.Get(data, path)
|
||||
if upstreamValue.Exists() && upstreamValue.Int() > 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
patchedData, err := sjson.Set(data, path, localValue)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return patchedData
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
if claudeInfo == nil {
|
||||
return false
|
||||
@@ -638,6 +712,12 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
if claudeResponse.Message != nil {
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 确保 message_delta 的 usage 包含完整的 input_tokens 和 cache 相关字段
|
||||
// 解决 AWS Bedrock 等上游返回的 message_delta 缺少这些字段的问题
|
||||
if !shouldSkipClaudeMessageDeltaUsagePatch(info) {
|
||||
data = patchClaudeMessageDeltaUsageData(data, buildMessageDeltaPatchUsage(&claudeResponse, claudeInfo))
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
|
||||
175
relay/channel/claude/relay_claude_test.go
Normal file
175
relay/channel/claude/relay_claude_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
claudeResponse := &dto.ClaudeResponse{
|
||||
Type: "message_start",
|
||||
Message: &dto.ClaudeMediaMessage{
|
||||
Id: "msg_123",
|
||||
Model: "claude-3-5-sonnet",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 1,
|
||||
CacheCreationInputTokens: 50,
|
||||
CacheReadInputTokens: 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
|
||||
if !ok {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if claudeInfo.Usage.PromptTokens != 100 {
|
||||
t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
if claudeInfo.Usage.PromptTokensDetails.CachedTokens != 30 {
|
||||
t.Errorf("CachedTokens = %d, want 30", claudeInfo.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
if claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens != 50 {
|
||||
t.Errorf("CachedCreationTokens = %d, want 50", claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens)
|
||||
}
|
||||
if claudeInfo.ResponseId != "msg_123" {
|
||||
t.Errorf("ResponseId = %s, want msg_123", claudeInfo.ResponseId)
|
||||
}
|
||||
if claudeInfo.Model != "claude-3-5-sonnet" {
|
||||
t.Errorf("Model = %s, want claude-3-5-sonnet", claudeInfo.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageDelta_FullUsage(t *testing.T) {
|
||||
// message_start 先积累 usage
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
CompletionTokens: 1,
|
||||
},
|
||||
}
|
||||
|
||||
// message_delta 带完整 usage(原生 Anthropic 场景)
|
||||
claudeResponse := &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 200,
|
||||
CacheCreationInputTokens: 50,
|
||||
CacheReadInputTokens: 30,
|
||||
},
|
||||
}
|
||||
|
||||
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
|
||||
if !ok {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if claudeInfo.Usage.PromptTokens != 100 {
|
||||
t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens != 200 {
|
||||
t.Errorf("CompletionTokens = %d, want 200", claudeInfo.Usage.CompletionTokens)
|
||||
}
|
||||
if claudeInfo.Usage.TotalTokens != 300 {
|
||||
t.Errorf("TotalTokens = %d, want 300", claudeInfo.Usage.TotalTokens)
|
||||
}
|
||||
if !claudeInfo.Done {
|
||||
t.Error("expected Done = true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatClaudeResponseInfo_MessageDelta_OnlyOutputTokens(t *testing.T) {
|
||||
// 模拟 Bedrock: message_start 已积累 usage
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
CachedCreationTokens: 50,
|
||||
},
|
||||
CompletionTokens: 1,
|
||||
ClaudeCacheCreation5mTokens: 10,
|
||||
ClaudeCacheCreation1hTokens: 20,
|
||||
},
|
||||
}
|
||||
|
||||
// Bedrock 的 message_delta 只有 output_tokens,缺少 input_tokens 和 cache 字段
|
||||
claudeResponse := &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
OutputTokens: 200,
|
||||
// InputTokens, CacheCreationInputTokens, CacheReadInputTokens 都是 0
|
||||
},
|
||||
}
|
||||
|
||||
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
|
||||
if !ok {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
// PromptTokens 应保持 message_start 的值(因为 message_delta 的 InputTokens=0,不更新)
|
||||
if claudeInfo.Usage.PromptTokens != 100 {
|
||||
t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens != 200 {
|
||||
t.Errorf("CompletionTokens = %d, want 200", claudeInfo.Usage.CompletionTokens)
|
||||
}
|
||||
if claudeInfo.Usage.TotalTokens != 300 {
|
||||
t.Errorf("TotalTokens = %d, want 300", claudeInfo.Usage.TotalTokens)
|
||||
}
|
||||
// cache 字段应保持 message_start 的值
|
||||
if claudeInfo.Usage.PromptTokensDetails.CachedTokens != 30 {
|
||||
t.Errorf("CachedTokens = %d, want 30", claudeInfo.Usage.PromptTokensDetails.CachedTokens)
|
||||
}
|
||||
if claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens != 50 {
|
||||
t.Errorf("CachedCreationTokens = %d, want 50", claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens)
|
||||
}
|
||||
if claudeInfo.Usage.ClaudeCacheCreation5mTokens != 10 {
|
||||
t.Errorf("ClaudeCacheCreation5mTokens = %d, want 10", claudeInfo.Usage.ClaudeCacheCreation5mTokens)
|
||||
}
|
||||
if claudeInfo.Usage.ClaudeCacheCreation1hTokens != 20 {
|
||||
t.Errorf("ClaudeCacheCreation1hTokens = %d, want 20", claudeInfo.Usage.ClaudeCacheCreation1hTokens)
|
||||
}
|
||||
if !claudeInfo.Done {
|
||||
t.Error("expected Done = true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatClaudeResponseInfo_NilClaudeInfo(t *testing.T) {
|
||||
claudeResponse := &dto.ClaudeResponse{Type: "message_start"}
|
||||
ok := FormatClaudeResponseInfo(claudeResponse, nil, nil)
|
||||
if ok {
|
||||
t.Error("expected false for nil claudeInfo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
|
||||
text := "hello"
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
Usage: &dto.Usage{},
|
||||
ResponseText: strings.Builder{},
|
||||
}
|
||||
claudeResponse := &dto.ClaudeResponse{
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Text: &text,
|
||||
},
|
||||
}
|
||||
|
||||
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
|
||||
if !ok {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
if claudeInfo.ResponseText.String() != "hello" {
|
||||
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,7 @@ func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
return nil, errors.New("codex channel: /v1/messages endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -41,15 +41,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
return nil, errors.New("codex channel: /v1/chat/completions endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
return nil, errors.New("codex channel: /v1/rerank endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
return nil, errors.New("codex channel: /v1/embeddings endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
|
||||
@@ -95,11 +95,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
|
||||
@@ -102,11 +102,8 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
|
||||
@@ -171,7 +171,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
return url, nil
|
||||
default:
|
||||
if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
|
||||
if (info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini) &&
|
||||
info.RelayMode != relayconstant.RelayModeResponses &&
|
||||
info.RelayMode != relayconstant.RelayModeResponsesCompact {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
|
||||
@@ -71,12 +71,22 @@ func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
chatResp.Usage = *usage
|
||||
}
|
||||
|
||||
chatBody, err := common.Marshal(chatResp)
|
||||
var responseBody []byte
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(chatResp, info)
|
||||
responseBody, err = common.Marshal(claudeResp)
|
||||
case types.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(chatResp, info)
|
||||
responseBody, err = common.Marshal(geminiResp)
|
||||
default:
|
||||
responseBody, err = common.Marshal(chatResp)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, chatBody)
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -106,14 +116,43 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
toolCallArgsByID := make(map[string]string)
|
||||
toolCallNameSent := make(map[string]bool)
|
||||
toolCallCanonicalIDByItemID := make(map[string]string)
|
||||
hasSentReasoningSummary := false
|
||||
needsReasoningSummarySeparator := false
|
||||
//reasoningSummaryTextByKey := make(map[string]string)
|
||||
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo == nil {
|
||||
info.ClaudeConvertInfo = &relaycommon.ClaudeConvertInfo{LastMessagesType: relaycommon.LastMessageTypeNone}
|
||||
}
|
||||
|
||||
sendChatChunk := func(chunk *dto.ChatCompletionsStreamResponse) bool {
|
||||
if chunk == nil {
|
||||
return true
|
||||
}
|
||||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
chunkData, err := common.Marshal(chunk)
|
||||
if err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
if err := HandleStreamFormat(c, info, string(chunkData), false, false); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
sendStartIfNeeded := func() bool {
|
||||
if sentStart {
|
||||
return true
|
||||
}
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)) {
|
||||
return false
|
||||
}
|
||||
sentStart = true
|
||||
@@ -154,6 +193,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
if delta == "" {
|
||||
return true
|
||||
}
|
||||
if needsReasoningSummarySeparator {
|
||||
if strings.HasPrefix(delta, "\n\n") {
|
||||
needsReasoningSummarySeparator = false
|
||||
} else if strings.HasPrefix(delta, "\n") {
|
||||
delta = "\n" + delta
|
||||
needsReasoningSummarySeparator = false
|
||||
} else {
|
||||
delta = "\n\n" + delta
|
||||
needsReasoningSummarySeparator = false
|
||||
}
|
||||
}
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
}
|
||||
@@ -173,10 +223,10 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(chunk) {
|
||||
return false
|
||||
}
|
||||
hasSentReasoningSummary = true
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -231,8 +281,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(chunk) {
|
||||
return false
|
||||
}
|
||||
sawToolCall = true
|
||||
@@ -282,6 +331,9 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
case "response.reasoning_summary_text.done":
|
||||
if hasSentReasoningSummary {
|
||||
needsReasoningSummarySeparator = true
|
||||
}
|
||||
|
||||
//case "response.reasoning_summary_part.added", "response.reasoning_summary_part.done":
|
||||
// key := responsesStreamIndexKey(strings.TrimSpace(streamResp.ItemID), streamResp.SummaryIndex)
|
||||
@@ -323,8 +375,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(chunk) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -419,13 +470,15 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
return false
|
||||
}
|
||||
if !sentStop {
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
|
||||
info.ClaudeConvertInfo.Usage = usage
|
||||
}
|
||||
finishReason := "stop"
|
||||
if sawToolCall && outputText.Len() == 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if err := helper.ObjectData(c, stop); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(stop) {
|
||||
return false
|
||||
}
|
||||
sentStop = true
|
||||
@@ -456,26 +509,31 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
if !sentStart {
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)) {
|
||||
return nil, streamErr
|
||||
}
|
||||
}
|
||||
if !sentStop {
|
||||
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
|
||||
info.ClaudeConvertInfo.Usage = usage
|
||||
}
|
||||
finishReason := "stop"
|
||||
if sawToolCall && outputText.Len() == 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if err := helper.ObjectData(c, stop); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
if !sendChatChunk(stop) {
|
||||
return nil, streamErr
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage && usage != nil {
|
||||
if info.RelayFormat == types.RelayFormatOpenAI && info.ShouldIncludeUsage && usage != nil {
|
||||
if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
helper.Done(c)
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -365,10 +365,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
claudeAdaptor := claude.Adaptor{}
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
return claudeAdaptor.DoResponse(c, resp, info)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
@@ -381,7 +382,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
return claudeAdaptor.DoResponse(c, resp, info)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
return gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
|
||||
@@ -347,10 +347,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
if _, ok := channelconstant.ChannelSpecialBases[info.ChannelBaseUrl]; ok {
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
}
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -109,11 +109,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
default:
|
||||
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
return zhipu4vImageHandler(c, resp, info)
|
||||
|
||||
@@ -110,6 +110,23 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
}
|
||||
|
||||
if !model_setting.GetGlobalSettings().PassThroughRequestEnabled &&
|
||||
!info.ChannelSetting.PassThroughBodyEnabled &&
|
||||
service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.ChannelType, info.OriginModelName) {
|
||||
openAIRequest, convErr := service.ClaudeToOpenAIRequest(*request, info)
|
||||
if convErr != nil {
|
||||
return types.NewError(convErr, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, openAIRequest)
|
||||
if newApiErr != nil {
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage)
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
|
||||
@@ -37,6 +37,9 @@ type ClaudeConvertInfo struct {
|
||||
Usage *dto.Usage
|
||||
FinishReason string
|
||||
Done bool
|
||||
|
||||
ToolCallBaseIndex int
|
||||
ToolCallMaxIndexOffset int
|
||||
}
|
||||
|
||||
type RerankerInfo struct {
|
||||
@@ -145,6 +148,8 @@ type RelayInfo struct {
|
||||
// RequestConversionChain records request format conversions in order, e.g.
|
||||
// ["openai", "openai_responses"] or ["openai", "claude"].
|
||||
RequestConversionChain []types.RelayFormat
|
||||
// 最终请求到上游的格式 TODO: 当前仅设置了Claude
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
ThinkingContentInfo
|
||||
TokenCountMeta
|
||||
@@ -319,12 +324,15 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
}
|
||||
if c.Query("beta") == "true" {
|
||||
info.IsClaudeBetaQuery = true
|
||||
}
|
||||
info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
|
||||
return info
|
||||
}
|
||||
|
||||
func isClaudeBetaForced(c *gin.Context) bool {
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
return ok && channelOtherSettings.ClaudeBetaQuery
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
|
||||
@@ -334,7 +334,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
isClaudeUsageSemantic := relayInfo.ChannelType == constant.ChannelTypeAnthropic
|
||||
isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
|
||||
@@ -207,6 +207,44 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
|
||||
var claudeResponses []*dto.ClaudeResponse
|
||||
// stopOpenBlocks emits the required content_block_stop event(s) for the currently open block(s)
|
||||
// according to Anthropic's SSE streaming state machine:
|
||||
// content_block_start -> content_block_delta* -> content_block_stop (per index).
|
||||
//
|
||||
// For text/thinking, there is at most one open block at info.ClaudeConvertInfo.Index.
|
||||
// For tools, OpenAI tool_calls can stream multiple parallel tool_use blocks (indexed from 0),
|
||||
// so we may have multiple open blocks and must stop each one explicitly.
|
||||
stopOpenBlocks := func() {
|
||||
switch info.ClaudeConvertInfo.LastMessagesType {
|
||||
case relaycommon.LastMessageTypeText, relaycommon.LastMessageTypeThinking:
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
case relaycommon.LastMessageTypeTools:
|
||||
base := info.ClaudeConvertInfo.ToolCallBaseIndex
|
||||
for offset := 0; offset <= info.ClaudeConvertInfo.ToolCallMaxIndexOffset; offset++ {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(base+offset))
|
||||
}
|
||||
}
|
||||
}
|
||||
// stopOpenBlocksAndAdvance closes the currently open block(s) and advances the content block index
|
||||
// to the next available slot for subsequent content_block_start events.
|
||||
//
|
||||
// This prevents invalid streams where a content_block_delta (e.g. thinking_delta) is emitted for an
|
||||
// index whose active content_block type is different (the typical cause of "Mismatched content block type").
|
||||
stopOpenBlocksAndAdvance := func() {
|
||||
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeNone {
|
||||
return
|
||||
}
|
||||
stopOpenBlocks()
|
||||
switch info.ClaudeConvertInfo.LastMessagesType {
|
||||
case relaycommon.LastMessageTypeTools:
|
||||
info.ClaudeConvertInfo.Index = info.ClaudeConvertInfo.ToolCallBaseIndex + info.ClaudeConvertInfo.ToolCallMaxIndexOffset + 1
|
||||
info.ClaudeConvertInfo.ToolCallBaseIndex = 0
|
||||
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
|
||||
default:
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeNone
|
||||
}
|
||||
if info.SendResponseCount == 1 {
|
||||
msg := &dto.ClaudeMediaMessage{
|
||||
Id: openAIResponse.Id,
|
||||
@@ -228,6 +266,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
//})
|
||||
if openAIResponse.IsToolCall() {
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
info.ClaudeConvertInfo.ToolCallBaseIndex = 0
|
||||
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
|
||||
var toolCall dto.ToolCallResponse
|
||||
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.ToolCalls) > 0 {
|
||||
toolCall = openAIResponse.Choices[0].Delta.ToolCalls[0]
|
||||
@@ -252,8 +292,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
claudeResponses = append(claudeResponses, resp)
|
||||
// 首块包含工具 delta,则追加 input_json_delta
|
||||
if toolCall.Function.Arguments != "" {
|
||||
idx := 0
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
@@ -270,16 +311,21 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
content := openAIResponse.Choices[0].Delta.GetContentString()
|
||||
|
||||
if reasoning != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||||
stopOpenBlocksAndAdvance()
|
||||
}
|
||||
idx := info.ClaudeConvertInfo.Index
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
idx2 := idx
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx2,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
@@ -288,16 +334,21 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
})
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
|
||||
} else if content != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
stopOpenBlocksAndAdvance()
|
||||
}
|
||||
idx := info.ClaudeConvertInfo.Index
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
idx2 := idx
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx2,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "text_delta",
|
||||
@@ -311,7 +362,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
// 如果首块就带 finish_reason,需要立即发送停止块
|
||||
if len(openAIResponse.Choices) > 0 && openAIResponse.Choices[0].FinishReason != nil && *openAIResponse.Choices[0].FinishReason != "" {
|
||||
info.FinishReason = *openAIResponse.Choices[0].FinishReason
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
stopOpenBlocks()
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
@@ -342,7 +393,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
// no choices
|
||||
// 可能为非标准的 OpenAI 响应,判断是否已经完成
|
||||
if info.ClaudeConvertInfo.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
stopOpenBlocks()
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
@@ -376,18 +427,25 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
if len(chosenChoice.Delta.ToolCalls) > 0 {
|
||||
toolCalls := chosenChoice.Delta.ToolCalls
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
stopOpenBlocksAndAdvance()
|
||||
info.ClaudeConvertInfo.ToolCallBaseIndex = info.ClaudeConvertInfo.Index
|
||||
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
|
||||
}
|
||||
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
|
||||
base := info.ClaudeConvertInfo.ToolCallBaseIndex
|
||||
maxOffset := info.ClaudeConvertInfo.ToolCallMaxIndexOffset
|
||||
|
||||
for i, toolCall := range toolCalls {
|
||||
blockIndex := info.ClaudeConvertInfo.Index
|
||||
offset := 0
|
||||
if toolCall.Index != nil {
|
||||
blockIndex = *toolCall.Index
|
||||
} else if len(toolCalls) > 1 {
|
||||
blockIndex = info.ClaudeConvertInfo.Index + i
|
||||
offset = *toolCall.Index
|
||||
} else {
|
||||
offset = i
|
||||
}
|
||||
if offset > maxOffset {
|
||||
maxOffset = offset
|
||||
}
|
||||
blockIndex := base + offset
|
||||
|
||||
idx := blockIndex
|
||||
if toolCall.Function.Name != "" {
|
||||
@@ -413,17 +471,19 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
info.ClaudeConvertInfo.Index = blockIndex
|
||||
}
|
||||
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = maxOffset
|
||||
info.ClaudeConvertInfo.Index = base + maxOffset
|
||||
} else {
|
||||
reasoning := chosenChoice.Delta.GetReasoningContent()
|
||||
textContent := chosenChoice.Delta.GetContentString()
|
||||
if reasoning != "" || textContent != "" {
|
||||
if reasoning != "" {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
|
||||
stopOpenBlocksAndAdvance()
|
||||
idx := info.ClaudeConvertInfo.Index
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
@@ -438,12 +498,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeThinking || info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeTools {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
info.ClaudeConvertInfo.Index++
|
||||
}
|
||||
stopOpenBlocksAndAdvance()
|
||||
idx := info.ClaudeConvertInfo.Index
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &info.ClaudeConvertInfo.Index,
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
@@ -462,13 +520,13 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
}
|
||||
|
||||
claudeResponse.Index = &info.ClaudeConvertInfo.Index
|
||||
claudeResponse.Index = common.GetPointer[int](info.ClaudeConvertInfo.Index)
|
||||
if !isEmpty && claudeResponse.Delta != nil {
|
||||
claudeResponses = append(claudeResponses, &claudeResponse)
|
||||
}
|
||||
|
||||
if doneChunk || info.ClaudeConvertInfo.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
stopOpenBlocks()
|
||||
oaiUsage := openAIResponse.Usage
|
||||
if oaiUsage == nil {
|
||||
oaiUsage = info.ClaudeConvertInfo.Usage
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
var defaultCacheRatio = map[string]float64{
|
||||
@@ -98,44 +95,37 @@ var defaultCreateCacheRatio = map[string]float64{
|
||||
|
||||
//var defaultCreateCacheRatio = map[string]float64{}
|
||||
|
||||
var cacheRatioMap map[string]float64
|
||||
var cacheRatioMapMutex sync.RWMutex
|
||||
var cacheRatioMap = types.NewRWMap[string, float64]()
|
||||
var createCacheRatioMap = types.NewRWMap[string, float64]()
|
||||
|
||||
// GetCacheRatioMap returns the cache ratio map
|
||||
// GetCacheRatioMap returns a copy of the cache ratio map
|
||||
func GetCacheRatioMap() map[string]float64 {
|
||||
cacheRatioMapMutex.RLock()
|
||||
defer cacheRatioMapMutex.RUnlock()
|
||||
return cacheRatioMap
|
||||
return cacheRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
// CacheRatio2JSONString converts the cache ratio map to a JSON string
|
||||
func CacheRatio2JSONString() string {
|
||||
cacheRatioMapMutex.RLock()
|
||||
defer cacheRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := json.Marshal(cacheRatioMap)
|
||||
if err != nil {
|
||||
common.SysLog("error marshalling cache ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return cacheRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
// CreateCacheRatio2JSONString converts the create cache ratio map to a JSON string
|
||||
func CreateCacheRatio2JSONString() string {
|
||||
return createCacheRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
// UpdateCacheRatioByJSONString updates the cache ratio map from a JSON string
|
||||
func UpdateCacheRatioByJSONString(jsonStr string) error {
|
||||
cacheRatioMapMutex.Lock()
|
||||
defer cacheRatioMapMutex.Unlock()
|
||||
cacheRatioMap = make(map[string]float64)
|
||||
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
return types.LoadFromJsonStringWithCallback(cacheRatioMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
// UpdateCreateCacheRatioByJSONString updates the create cache ratio map from a JSON string
|
||||
func UpdateCreateCacheRatioByJSONString(jsonStr string) error {
|
||||
return types.LoadFromJsonStringWithCallback(createCacheRatioMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
// GetCacheRatio returns the cache ratio for a model
|
||||
func GetCacheRatio(name string) (float64, bool) {
|
||||
cacheRatioMapMutex.RLock()
|
||||
defer cacheRatioMapMutex.RUnlock()
|
||||
ratio, ok := cacheRatioMap[name]
|
||||
ratio, ok := cacheRatioMap.Get(name)
|
||||
if !ok {
|
||||
return 1, false // Default to 1 if not found
|
||||
}
|
||||
@@ -143,7 +133,7 @@ func GetCacheRatio(name string) (float64, bool) {
|
||||
}
|
||||
|
||||
func GetCreateCacheRatio(name string) (float64, bool) {
|
||||
ratio, ok := defaultCreateCacheRatio[name]
|
||||
ratio, ok := createCacheRatioMap.Get(name)
|
||||
if !ok {
|
||||
return 1.25, false // Default to 1.25 if not found
|
||||
}
|
||||
@@ -151,11 +141,9 @@ func GetCreateCacheRatio(name string) (float64, bool) {
|
||||
}
|
||||
|
||||
func GetCacheRatioCopy() map[string]float64 {
|
||||
cacheRatioMapMutex.RLock()
|
||||
defer cacheRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(cacheRatioMap))
|
||||
for k, v := range cacheRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
return cacheRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
func GetCreateCacheRatioCopy() map[string]float64 {
|
||||
return createCacheRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
@@ -42,10 +42,11 @@ func GetExposedData() gin.H {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
newData := gin.H{
|
||||
"model_ratio": GetModelRatioCopy(),
|
||||
"completion_ratio": GetCompletionRatioCopy(),
|
||||
"cache_ratio": GetCacheRatioCopy(),
|
||||
"model_price": GetModelPriceCopy(),
|
||||
"model_ratio": GetModelRatioCopy(),
|
||||
"completion_ratio": GetCompletionRatioCopy(),
|
||||
"cache_ratio": GetCacheRatioCopy(),
|
||||
"create_cache_ratio": GetCreateCacheRatioCopy(),
|
||||
"model_price": GetModelPriceCopy(),
|
||||
}
|
||||
exposedData.Store(&exposedCache{
|
||||
data: newData,
|
||||
|
||||
@@ -3,29 +3,27 @@ package ratio_setting
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
var groupRatio = map[string]float64{
|
||||
var defaultGroupRatio = map[string]float64{
|
||||
"default": 1,
|
||||
"vip": 1,
|
||||
"svip": 1,
|
||||
}
|
||||
|
||||
var groupRatioMutex sync.RWMutex
|
||||
var groupRatioMap = types.NewRWMap[string, float64]()
|
||||
|
||||
var (
|
||||
GroupGroupRatio = map[string]map[string]float64{
|
||||
"vip": {
|
||||
"edit_this": 0.9,
|
||||
},
|
||||
}
|
||||
groupGroupRatioMutex sync.RWMutex
|
||||
)
|
||||
var defaultGroupGroupRatio = map[string]map[string]float64{
|
||||
"vip": {
|
||||
"edit_this": 0.9,
|
||||
},
|
||||
}
|
||||
|
||||
var groupGroupRatioMap = types.NewRWMap[string, map[string]float64]()
|
||||
|
||||
var defaultGroupSpecialUsableGroup = map[string]map[string]string{
|
||||
"vip": {
|
||||
@@ -35,9 +33,9 @@ var defaultGroupSpecialUsableGroup = map[string]map[string]string{
|
||||
}
|
||||
|
||||
type GroupRatioSetting struct {
|
||||
GroupRatio map[string]float64 `json:"group_ratio"`
|
||||
GroupGroupRatio map[string]map[string]float64 `json:"group_group_ratio"`
|
||||
GroupSpecialUsableGroup *types.RWMap[string, map[string]string] `json:"group_special_usable_group"`
|
||||
GroupRatio *types.RWMap[string, float64] `json:"group_ratio"`
|
||||
GroupGroupRatio *types.RWMap[string, map[string]float64] `json:"group_group_ratio"`
|
||||
GroupSpecialUsableGroup *types.RWMap[string, map[string]string] `json:"group_special_usable_group"`
|
||||
}
|
||||
|
||||
var groupRatioSetting GroupRatioSetting
|
||||
@@ -46,10 +44,13 @@ func init() {
|
||||
groupSpecialUsableGroup := types.NewRWMap[string, map[string]string]()
|
||||
groupSpecialUsableGroup.AddAll(defaultGroupSpecialUsableGroup)
|
||||
|
||||
groupRatioMap.AddAll(defaultGroupRatio)
|
||||
groupGroupRatioMap.AddAll(defaultGroupGroupRatio)
|
||||
|
||||
groupRatioSetting = GroupRatioSetting{
|
||||
GroupSpecialUsableGroup: groupSpecialUsableGroup,
|
||||
GroupRatio: groupRatio,
|
||||
GroupGroupRatio: GroupGroupRatio,
|
||||
GroupRatio: groupRatioMap,
|
||||
GroupGroupRatio: groupGroupRatioMap,
|
||||
}
|
||||
|
||||
config.GlobalConfig.Register("group_ratio_setting", &groupRatioSetting)
|
||||
@@ -64,48 +65,24 @@ func GetGroupRatioSetting() *GroupRatioSetting {
|
||||
}
|
||||
|
||||
func GetGroupRatioCopy() map[string]float64 {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
groupRatioCopy := make(map[string]float64)
|
||||
for k, v := range groupRatio {
|
||||
groupRatioCopy[k] = v
|
||||
}
|
||||
return groupRatioCopy
|
||||
return groupRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
func ContainsGroupRatio(name string) bool {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
_, ok := groupRatio[name]
|
||||
_, ok := groupRatioMap.Get(name)
|
||||
return ok
|
||||
}
|
||||
|
||||
func GroupRatio2JSONString() string {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
jsonBytes, err := json.Marshal(groupRatio)
|
||||
if err != nil {
|
||||
common.SysLog("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return groupRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateGroupRatioByJSONString(jsonStr string) error {
|
||||
groupRatioMutex.Lock()
|
||||
defer groupRatioMutex.Unlock()
|
||||
|
||||
groupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &groupRatio)
|
||||
return types.LoadFromJsonString(groupRatioMap, jsonStr)
|
||||
}
|
||||
|
||||
func GetGroupRatio(name string) float64 {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
ratio, ok := groupRatio[name]
|
||||
ratio, ok := groupRatioMap.Get(name)
|
||||
if !ok {
|
||||
common.SysLog("group ratio not found: " + name)
|
||||
return 1
|
||||
@@ -114,10 +91,7 @@ func GetGroupRatio(name string) float64 {
|
||||
}
|
||||
|
||||
func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
|
||||
groupGroupRatioMutex.RLock()
|
||||
defer groupGroupRatioMutex.RUnlock()
|
||||
|
||||
gp, ok := GroupGroupRatio[userGroup]
|
||||
gp, ok := groupGroupRatioMap.Get(userGroup)
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
@@ -129,22 +103,11 @@ func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
|
||||
}
|
||||
|
||||
func GroupGroupRatio2JSONString() string {
|
||||
groupGroupRatioMutex.RLock()
|
||||
defer groupGroupRatioMutex.RUnlock()
|
||||
|
||||
jsonBytes, err := json.Marshal(GroupGroupRatio)
|
||||
if err != nil {
|
||||
common.SysLog("error marshalling group-group ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return groupGroupRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateGroupGroupRatioByJSONString(jsonStr string) error {
|
||||
groupGroupRatioMutex.Lock()
|
||||
defer groupGroupRatioMutex.Unlock()
|
||||
|
||||
GroupGroupRatio = make(map[string]map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &GroupGroupRatio)
|
||||
return types.LoadFromJsonString(groupGroupRatioMap, jsonStr)
|
||||
}
|
||||
|
||||
func CheckGroupRatio(jsonStr string) error {
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
// from songquanpeng/one-api
|
||||
@@ -319,19 +318,9 @@ var defaultAudioCompletionRatio = map[string]float64{
|
||||
"tts-1-hd-1106": 0,
|
||||
}
|
||||
|
||||
var (
|
||||
modelPriceMap map[string]float64 = nil
|
||||
modelPriceMapMutex = sync.RWMutex{}
|
||||
)
|
||||
var (
|
||||
modelRatioMap map[string]float64 = nil
|
||||
modelRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
|
||||
var (
|
||||
CompletionRatio map[string]float64 = nil
|
||||
CompletionRatioMutex = sync.RWMutex{}
|
||||
)
|
||||
var modelPriceMap = types.NewRWMap[string, float64]()
|
||||
var modelRatioMap = types.NewRWMap[string, float64]()
|
||||
var completionRatioMap = types.NewRWMap[string, float64]()
|
||||
|
||||
var defaultCompletionRatio = map[string]float64{
|
||||
"gpt-4-gizmo-*": 2,
|
||||
@@ -342,79 +331,34 @@ var defaultCompletionRatio = map[string]float64{
|
||||
|
||||
// InitRatioSettings initializes all model related settings maps
|
||||
func InitRatioSettings() {
|
||||
// Initialize modelPriceMap
|
||||
modelPriceMapMutex.Lock()
|
||||
modelPriceMap = defaultModelPrice
|
||||
modelPriceMapMutex.Unlock()
|
||||
|
||||
// Initialize modelRatioMap
|
||||
modelRatioMapMutex.Lock()
|
||||
modelRatioMap = defaultModelRatio
|
||||
modelRatioMapMutex.Unlock()
|
||||
|
||||
// Initialize CompletionRatio
|
||||
CompletionRatioMutex.Lock()
|
||||
CompletionRatio = defaultCompletionRatio
|
||||
CompletionRatioMutex.Unlock()
|
||||
|
||||
// Initialize cacheRatioMap
|
||||
cacheRatioMapMutex.Lock()
|
||||
cacheRatioMap = defaultCacheRatio
|
||||
cacheRatioMapMutex.Unlock()
|
||||
|
||||
// initialize imageRatioMap
|
||||
imageRatioMapMutex.Lock()
|
||||
imageRatioMap = defaultImageRatio
|
||||
imageRatioMapMutex.Unlock()
|
||||
|
||||
// initialize audioRatioMap
|
||||
audioRatioMapMutex.Lock()
|
||||
audioRatioMap = defaultAudioRatio
|
||||
audioRatioMapMutex.Unlock()
|
||||
|
||||
// initialize audioCompletionRatioMap
|
||||
audioCompletionRatioMapMutex.Lock()
|
||||
audioCompletionRatioMap = defaultAudioCompletionRatio
|
||||
audioCompletionRatioMapMutex.Unlock()
|
||||
modelPriceMap.AddAll(defaultModelPrice)
|
||||
modelRatioMap.AddAll(defaultModelRatio)
|
||||
completionRatioMap.AddAll(defaultCompletionRatio)
|
||||
cacheRatioMap.AddAll(defaultCacheRatio)
|
||||
createCacheRatioMap.AddAll(defaultCreateCacheRatio)
|
||||
imageRatioMap.AddAll(defaultImageRatio)
|
||||
audioRatioMap.AddAll(defaultAudioRatio)
|
||||
audioCompletionRatioMap.AddAll(defaultAudioCompletionRatio)
|
||||
}
|
||||
|
||||
func GetModelPriceMap() map[string]float64 {
|
||||
modelPriceMapMutex.RLock()
|
||||
defer modelPriceMapMutex.RUnlock()
|
||||
return modelPriceMap
|
||||
return modelPriceMap.ReadAll()
|
||||
}
|
||||
|
||||
func ModelPrice2JSONString() string {
|
||||
modelPriceMapMutex.RLock()
|
||||
defer modelPriceMapMutex.RUnlock()
|
||||
|
||||
jsonBytes, err := common.Marshal(modelPriceMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling model price: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return modelPriceMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||
modelPriceMapMutex.Lock()
|
||||
defer modelPriceMapMutex.Unlock()
|
||||
modelPriceMap = make(map[string]float64)
|
||||
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
return types.LoadFromJsonStringWithCallback(modelPriceMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
||||
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||
modelPriceMapMutex.RLock()
|
||||
defer modelPriceMapMutex.RUnlock()
|
||||
|
||||
name = FormatMatchingModelName(name)
|
||||
|
||||
if strings.HasSuffix(name, CompactModelSuffix) {
|
||||
price, ok := modelPriceMap[CompactWildcardModelKey]
|
||||
price, ok := modelPriceMap.Get(CompactWildcardModelKey)
|
||||
if !ok {
|
||||
if printErr {
|
||||
common.SysError("model price not found: " + name)
|
||||
@@ -424,7 +368,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||
return price, true
|
||||
}
|
||||
|
||||
price, ok := modelPriceMap[name]
|
||||
price, ok := modelPriceMap.Get(name)
|
||||
if !ok {
|
||||
if printErr {
|
||||
common.SysError("model price not found: " + name)
|
||||
@@ -435,14 +379,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||
}
|
||||
|
||||
func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
modelRatioMapMutex.Lock()
|
||||
defer modelRatioMapMutex.Unlock()
|
||||
modelRatioMap = make(map[string]float64)
|
||||
err := common.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
return types.LoadFromJsonStringWithCallback(modelRatioMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
// 处理带有思考预算的模型名称,方便统一定价
|
||||
@@ -454,15 +391,12 @@ func handleThinkingBudgetModel(name, prefix, wildcard string) string {
|
||||
}
|
||||
|
||||
func GetModelRatio(name string) (float64, bool, string) {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
|
||||
name = FormatMatchingModelName(name)
|
||||
|
||||
ratio, ok := modelRatioMap[name]
|
||||
ratio, ok := modelRatioMap.Get(name)
|
||||
if !ok {
|
||||
if strings.HasSuffix(name, CompactModelSuffix) {
|
||||
if wildcardRatio, ok := modelRatioMap[CompactWildcardModelKey]; ok {
|
||||
if wildcardRatio, ok := modelRatioMap.Get(CompactWildcardModelKey); ok {
|
||||
return wildcardRatio, true, name
|
||||
}
|
||||
//return 0, true, name
|
||||
@@ -488,54 +422,19 @@ func GetDefaultModelPriceMap() map[string]float64 {
|
||||
return defaultModelPrice
|
||||
}
|
||||
|
||||
func GetDefaultImageRatioMap() map[string]float64 {
|
||||
return defaultImageRatio
|
||||
}
|
||||
|
||||
func GetDefaultAudioRatioMap() map[string]float64 {
|
||||
return defaultAudioRatio
|
||||
}
|
||||
|
||||
func GetDefaultAudioCompletionRatioMap() map[string]float64 {
|
||||
return defaultAudioCompletionRatio
|
||||
}
|
||||
|
||||
func GetCompletionRatioMap() map[string]float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
return CompletionRatio
|
||||
}
|
||||
|
||||
func CompletionRatio2JSONString() string {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
|
||||
jsonBytes, err := json.Marshal(CompletionRatio)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling completion ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return completionRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
||||
CompletionRatioMutex.Lock()
|
||||
defer CompletionRatioMutex.Unlock()
|
||||
CompletionRatio = make(map[string]float64)
|
||||
err := common.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
return types.LoadFromJsonStringWithCallback(completionRatioMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
|
||||
name = FormatMatchingModelName(name)
|
||||
|
||||
if strings.Contains(name, "/") {
|
||||
if ratio, ok := CompletionRatio[name]; ok {
|
||||
if ratio, ok := completionRatioMap.Get(name); ok {
|
||||
return ratio
|
||||
}
|
||||
}
|
||||
@@ -543,7 +442,7 @@ func GetCompletionRatio(name string) float64 {
|
||||
if contain {
|
||||
return hardCodedRatio
|
||||
}
|
||||
if ratio, ok := CompletionRatio[name]; ok {
|
||||
if ratio, ok := completionRatioMap.Get(name); ok {
|
||||
return ratio
|
||||
}
|
||||
return hardCodedRatio
|
||||
@@ -671,88 +570,54 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
}
|
||||
|
||||
func GetAudioRatio(name string) float64 {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
if ratio, ok := audioRatioMap[name]; ok {
|
||||
if ratio, ok := audioRatioMap.Get(name); ok {
|
||||
return ratio
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatio(name string) float64 {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
if ratio, ok := audioCompletionRatioMap[name]; ok {
|
||||
|
||||
if ratio, ok := audioCompletionRatioMap.Get(name); ok {
|
||||
return ratio
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func ContainsAudioRatio(name string) bool {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
_, ok := audioRatioMap[name]
|
||||
_, ok := audioRatioMap.Get(name)
|
||||
return ok
|
||||
}
|
||||
|
||||
func ContainsAudioCompletionRatio(name string) bool {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
_, ok := audioCompletionRatioMap[name]
|
||||
_, ok := audioCompletionRatioMap.Get(name)
|
||||
return ok
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
|
||||
jsonBytes, err := common.Marshal(modelRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return modelRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
var defaultImageRatio = map[string]float64{
|
||||
"gpt-image-1": 2,
|
||||
}
|
||||
var imageRatioMap map[string]float64
|
||||
var imageRatioMapMutex sync.RWMutex
|
||||
var (
|
||||
audioRatioMap map[string]float64 = nil
|
||||
audioRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
var (
|
||||
audioCompletionRatioMap map[string]float64 = nil
|
||||
audioCompletionRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
var imageRatioMap = types.NewRWMap[string, float64]()
|
||||
var audioRatioMap = types.NewRWMap[string, float64]()
|
||||
var audioCompletionRatioMap = types.NewRWMap[string, float64]()
|
||||
|
||||
func ImageRatio2JSONString() string {
|
||||
imageRatioMapMutex.RLock()
|
||||
defer imageRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(imageRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling cache ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return imageRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateImageRatioByJSONString(jsonStr string) error {
|
||||
imageRatioMapMutex.Lock()
|
||||
defer imageRatioMapMutex.Unlock()
|
||||
imageRatioMap = make(map[string]float64)
|
||||
return common.Unmarshal([]byte(jsonStr), &imageRatioMap)
|
||||
return types.LoadFromJsonString(imageRatioMap, jsonStr)
|
||||
}
|
||||
|
||||
func GetImageRatio(name string) (float64, bool) {
|
||||
imageRatioMapMutex.RLock()
|
||||
defer imageRatioMapMutex.RUnlock()
|
||||
ratio, ok := imageRatioMap[name]
|
||||
ratio, ok := imageRatioMap.Get(name)
|
||||
if !ok {
|
||||
return 1, false // Default to 1 if not found
|
||||
}
|
||||
@@ -760,78 +625,31 @@ func GetImageRatio(name string) (float64, bool) {
|
||||
}
|
||||
|
||||
func AudioRatio2JSONString() string {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(audioRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling audio ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return audioRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateAudioRatioByJSONString(jsonStr string) error {
|
||||
|
||||
tmp := make(map[string]float64)
|
||||
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
audioRatioMapMutex.Lock()
|
||||
audioRatioMap = tmp
|
||||
audioRatioMapMutex.Unlock()
|
||||
InvalidateExposedDataCache()
|
||||
return nil
|
||||
return types.LoadFromJsonStringWithCallback(audioRatioMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
func AudioCompletionRatio2JSONString() string {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(audioCompletionRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling audio completion ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
return audioCompletionRatioMap.MarshalJSONString()
|
||||
}
|
||||
|
||||
func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
|
||||
tmp := make(map[string]float64)
|
||||
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
audioCompletionRatioMapMutex.Lock()
|
||||
audioCompletionRatioMap = tmp
|
||||
audioCompletionRatioMapMutex.Unlock()
|
||||
InvalidateExposedDataCache()
|
||||
return nil
|
||||
return types.LoadFromJsonStringWithCallback(audioCompletionRatioMap, jsonStr, InvalidateExposedDataCache)
|
||||
}
|
||||
|
||||
func GetModelRatioCopy() map[string]float64 {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
return modelRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
func GetModelPriceCopy() map[string]float64 {
|
||||
modelPriceMapMutex.RLock()
|
||||
defer modelPriceMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
return modelPriceMap.ReadAll()
|
||||
}
|
||||
|
||||
func GetCompletionRatioCopy() map[string]float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(CompletionRatio))
|
||||
for k, v := range CompletionRatio {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
return completionRatioMap.ReadAll()
|
||||
}
|
||||
|
||||
// 转换模型名,减少渠道必须配置各种带参数模型
|
||||
|
||||
@@ -80,3 +80,24 @@ func LoadFromJsonString[K comparable, V any](m *RWMap[K, V], jsonStr string) err
|
||||
m.data = make(map[K]V)
|
||||
return common.Unmarshal([]byte(jsonStr), &m.data)
|
||||
}
|
||||
|
||||
// LoadFromJsonStringWithCallback loads a JSON string into the RWMap and calls the callback on success.
|
||||
func LoadFromJsonStringWithCallback[K comparable, V any](m *RWMap[K, V], jsonStr string, onSuccess func()) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.data = make(map[K]V)
|
||||
err := common.Unmarshal([]byte(jsonStr), &m.data)
|
||||
if err == nil && onSuccess != nil {
|
||||
onSuccess()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalJSONString returns the JSON string representation of the RWMap.
|
||||
func (m *RWMap[K, V]) MarshalJSONString() string {
|
||||
bytes, err := m.MarshalJSON()
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ const RatioSetting = () => {
|
||||
ModelPrice: '',
|
||||
ModelRatio: '',
|
||||
CacheRatio: '',
|
||||
CreateCacheRatio: '',
|
||||
CompletionRatio: '',
|
||||
GroupRatio: '',
|
||||
GroupGroupRatio: '',
|
||||
|
||||
@@ -107,9 +107,11 @@ const AccountManagement = ({
|
||||
const res = await API.get('/api/user/oauth/bindings');
|
||||
if (res.data.success) {
|
||||
setCustomOAuthBindings(res.data.data || []);
|
||||
} else {
|
||||
showError(res.data.message || t('获取绑定信息失败'));
|
||||
}
|
||||
} catch (error) {
|
||||
// ignore
|
||||
showError(error.response?.data?.message || error.message || t('获取绑定信息失败'));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -131,7 +133,7 @@ const AccountManagement = ({
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('操作失败'));
|
||||
showError(error.response?.data?.message || error.message || t('操作失败'));
|
||||
} finally {
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
|
||||
}
|
||||
|
||||
@@ -170,6 +170,7 @@ const EditChannelModal = (props) => {
|
||||
allow_service_tier: false,
|
||||
disable_store: false, // false = 允许透传(默认开启)
|
||||
allow_safety_identifier: false,
|
||||
claude_beta_query: false,
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
const [multiToSingle, setMultiToSingle] = useState(false);
|
||||
@@ -633,6 +634,7 @@ const EditChannelModal = (props) => {
|
||||
data.disable_store = parsedSettings.disable_store || false;
|
||||
data.allow_safety_identifier =
|
||||
parsedSettings.allow_safety_identifier || false;
|
||||
data.claude_beta_query = parsedSettings.claude_beta_query || false;
|
||||
} catch (error) {
|
||||
console.error('解析其他设置失败:', error);
|
||||
data.azure_responses_version = '';
|
||||
@@ -643,6 +645,7 @@ const EditChannelModal = (props) => {
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
data.allow_safety_identifier = false;
|
||||
data.claude_beta_query = false;
|
||||
}
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
@@ -652,6 +655,7 @@ const EditChannelModal = (props) => {
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
data.allow_safety_identifier = false;
|
||||
data.claude_beta_query = false;
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -1394,6 +1398,9 @@ const EditChannelModal = (props) => {
|
||||
settings.allow_safety_identifier =
|
||||
localInputs.allow_safety_identifier === true;
|
||||
}
|
||||
if (localInputs.type === 14) {
|
||||
settings.claude_beta_query = localInputs.claude_beta_query === true;
|
||||
}
|
||||
}
|
||||
|
||||
localInputs.settings = JSON.stringify(settings);
|
||||
@@ -1414,6 +1421,7 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.allow_service_tier;
|
||||
delete localInputs.disable_store;
|
||||
delete localInputs.allow_safety_identifier;
|
||||
delete localInputs.claude_beta_query;
|
||||
|
||||
let res;
|
||||
localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
|
||||
@@ -3316,6 +3324,24 @@ const EditChannelModal = (props) => {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{inputs.type === 14 && (
|
||||
<Form.Switch
|
||||
field='claude_beta_query'
|
||||
label={t('Claude 强制 beta=true')}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(value) =>
|
||||
handleChannelOtherSettingsChange(
|
||||
'claude_beta_query',
|
||||
value,
|
||||
)
|
||||
}
|
||||
extraText={t(
|
||||
'开启后,该渠道请求 Claude 时将强制追加 ?beta=true(无需客户端手动传参)',
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{inputs.type === 1 && (
|
||||
<Form.Switch
|
||||
field='force_format'
|
||||
|
||||
@@ -26,8 +26,10 @@ import {
|
||||
Tag,
|
||||
Typography,
|
||||
Select,
|
||||
Switch,
|
||||
Banner,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
|
||||
import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
|
||||
import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants';
|
||||
|
||||
@@ -48,11 +50,25 @@ const ModelTestModal = ({
|
||||
setModelTablePage,
|
||||
selectedEndpointType,
|
||||
setSelectedEndpointType,
|
||||
isStreamTest,
|
||||
setIsStreamTest,
|
||||
allSelectingRef,
|
||||
isMobile,
|
||||
t,
|
||||
}) => {
|
||||
const hasChannel = Boolean(currentTestChannel);
|
||||
const streamToggleDisabled = [
|
||||
'embeddings',
|
||||
'image-generation',
|
||||
'jina-rerank',
|
||||
'openai-response-compact',
|
||||
].includes(selectedEndpointType);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (streamToggleDisabled && isStreamTest) {
|
||||
setIsStreamTest(false);
|
||||
}
|
||||
}, [streamToggleDisabled, isStreamTest, setIsStreamTest]);
|
||||
|
||||
const filteredModels = hasChannel
|
||||
? currentTestChannel.models
|
||||
@@ -181,6 +197,7 @@ const ModelTestModal = ({
|
||||
currentTestChannel,
|
||||
record.model,
|
||||
selectedEndpointType,
|
||||
isStreamTest,
|
||||
)
|
||||
}
|
||||
loading={isTesting}
|
||||
@@ -258,25 +275,46 @@ const ModelTestModal = ({
|
||||
>
|
||||
{hasChannel && (
|
||||
<div className='model-test-scroll'>
|
||||
{/* 端点类型选择器 */}
|
||||
<div className='flex items-center gap-2 w-full mb-2'>
|
||||
<Typography.Text strong>{t('端点类型')}:</Typography.Text>
|
||||
<Select
|
||||
value={selectedEndpointType}
|
||||
onChange={setSelectedEndpointType}
|
||||
optionList={endpointTypeOptions}
|
||||
className='!w-full'
|
||||
placeholder={t('选择端点类型')}
|
||||
/>
|
||||
{/* Endpoint toolbar */}
|
||||
<div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
|
||||
<div className='flex items-center gap-2 flex-1 min-w-0'>
|
||||
<Typography.Text strong className='shrink-0'>
|
||||
{t('端点类型')}:
|
||||
</Typography.Text>
|
||||
<Select
|
||||
value={selectedEndpointType}
|
||||
onChange={setSelectedEndpointType}
|
||||
optionList={endpointTypeOptions}
|
||||
className='!w-full min-w-0'
|
||||
placeholder={t('选择端点类型')}
|
||||
/>
|
||||
</div>
|
||||
<div className='flex items-center justify-between sm:justify-end gap-2 shrink-0'>
|
||||
<Typography.Text strong className='shrink-0'>
|
||||
{t('流式')}:
|
||||
</Typography.Text>
|
||||
<Switch
|
||||
checked={isStreamTest}
|
||||
onChange={setIsStreamTest}
|
||||
size='small'
|
||||
disabled={streamToggleDisabled}
|
||||
aria-label={t('流式')}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<Typography.Text type='tertiary' size='small' className='block mb-2'>
|
||||
{t(
|
||||
|
||||
<Banner
|
||||
type='info'
|
||||
closeIcon={null}
|
||||
icon={<IconInfoCircle />}
|
||||
className='!rounded-lg mb-2'
|
||||
description={t(
|
||||
'说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。',
|
||||
)}
|
||||
</Typography.Text>
|
||||
/>
|
||||
|
||||
{/* 搜索与操作按钮 */}
|
||||
<div className='flex items-center justify-end gap-2 w-full mb-2'>
|
||||
<div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
|
||||
<Input
|
||||
placeholder={t('搜索模型...')}
|
||||
value={modelSearchKeyword}
|
||||
@@ -284,16 +322,17 @@ const ModelTestModal = ({
|
||||
setModelSearchKeyword(v);
|
||||
setModelTablePage(1);
|
||||
}}
|
||||
className='!w-full'
|
||||
className='!w-full sm:!flex-1'
|
||||
prefix={<IconSearch />}
|
||||
showClear
|
||||
/>
|
||||
|
||||
<Button onClick={handleCopySelected}>{t('复制已选')}</Button>
|
||||
|
||||
<Button type='tertiary' onClick={handleSelectSuccess}>
|
||||
{t('选择成功')}
|
||||
</Button>
|
||||
<div className='flex items-center justify-end gap-2'>
|
||||
<Button onClick={handleCopySelected}>{t('复制已选')}</Button>
|
||||
<Button type='tertiary' onClick={handleSelectSuccess}>
|
||||
{t('选择成功')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Table
|
||||
|
||||
@@ -87,6 +87,7 @@ export const useChannelsData = () => {
|
||||
const [isBatchTesting, setIsBatchTesting] = useState(false);
|
||||
const [modelTablePage, setModelTablePage] = useState(1);
|
||||
const [selectedEndpointType, setSelectedEndpointType] = useState('');
|
||||
const [isStreamTest, setIsStreamTest] = useState(false);
|
||||
const [globalPassThroughEnabled, setGlobalPassThroughEnabled] =
|
||||
useState(false);
|
||||
|
||||
@@ -851,7 +852,12 @@ export const useChannelsData = () => {
|
||||
};
|
||||
|
||||
// Test channel - 单个模型测试,参考旧版实现
|
||||
const testChannel = async (record, model, endpointType = '') => {
|
||||
const testChannel = async (
|
||||
record,
|
||||
model,
|
||||
endpointType = '',
|
||||
stream = false,
|
||||
) => {
|
||||
const testKey = `${record.id}-${model}`;
|
||||
|
||||
// 检查是否应该停止批量测试
|
||||
@@ -867,6 +873,9 @@ export const useChannelsData = () => {
|
||||
if (endpointType) {
|
||||
url += `&endpoint_type=${endpointType}`;
|
||||
}
|
||||
if (stream) {
|
||||
url += `&stream=true`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
|
||||
// 检查是否在请求期间被停止
|
||||
@@ -995,7 +1004,12 @@ export const useChannelsData = () => {
|
||||
);
|
||||
|
||||
const batchPromises = batch.map((model) =>
|
||||
testChannel(currentTestChannel, model, selectedEndpointType),
|
||||
testChannel(
|
||||
currentTestChannel,
|
||||
model,
|
||||
selectedEndpointType,
|
||||
isStreamTest,
|
||||
),
|
||||
);
|
||||
const batchResults = await Promise.allSettled(batchPromises);
|
||||
results.push(...batchResults);
|
||||
@@ -1080,6 +1094,7 @@ export const useChannelsData = () => {
|
||||
setSelectedModelKeys([]);
|
||||
setModelTablePage(1);
|
||||
setSelectedEndpointType('');
|
||||
setIsStreamTest(false);
|
||||
// 可选择性保留测试结果,这里不清空以便用户查看
|
||||
};
|
||||
|
||||
@@ -1170,6 +1185,8 @@ export const useChannelsData = () => {
|
||||
setModelTablePage,
|
||||
selectedEndpointType,
|
||||
setSelectedEndpointType,
|
||||
isStreamTest,
|
||||
setIsStreamTest,
|
||||
allSelectingRef,
|
||||
|
||||
// Multi-key management states
|
||||
|
||||
@@ -1146,6 +1146,8 @@
|
||||
"提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址": "Tip: {key} in the link will be replaced with the API key, {address} will be replaced with the server address",
|
||||
"提示价格:{{symbol}}{{price}} / 1M tokens": "Prompt price: {{symbol}}{{price}} / 1M tokens",
|
||||
"提示缓存倍率": "Prompt cache ratio",
|
||||
"缓存创建倍率": "Cache creation ratio",
|
||||
"默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "Defaults to the 5m cache creation ratio; the 1h cache creation ratio is computed by fixed multiplication (currently 1.6x)",
|
||||
"搜索供应商": "Search vendor",
|
||||
"搜索关键字": "Search keywords",
|
||||
"搜索失败": "Search failed",
|
||||
@@ -1548,6 +1550,7 @@
|
||||
"流": "stream",
|
||||
"流式响应完成": "Streaming response completed",
|
||||
"流式输出": "Streaming Output",
|
||||
"流式": "Streaming",
|
||||
"流量端口": "Traffic Port",
|
||||
"浅色": "Light",
|
||||
"浅色模式": "Light Mode",
|
||||
|
||||
@@ -1156,6 +1156,8 @@
|
||||
"提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址": "Astuce : {key} dans le lien sera remplacé par la clé API, {address} sera remplacé par l'adresse du serveur",
|
||||
"提示价格:{{symbol}}{{price}} / 1M tokens": "Prix d'invite : {{symbol}}{{price}} / 1M tokens",
|
||||
"提示缓存倍率": "Ratio de cache d'invite",
|
||||
"缓存创建倍率": "Ratio de création du cache",
|
||||
"默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "Par défaut, le ratio de création de cache 5m est utilisé ; le ratio de création de cache 1h est calculé via une multiplication fixe (actuellement 1.6x)",
|
||||
"搜索供应商": "Rechercher un fournisseur",
|
||||
"搜索关键字": "Rechercher des mots-clés",
|
||||
"搜索失败": "Search failed",
|
||||
@@ -1558,6 +1560,7 @@
|
||||
"流": "Flux",
|
||||
"流式响应完成": "Flux terminé",
|
||||
"流式输出": "Sortie en flux",
|
||||
"流式": "Streaming",
|
||||
"流量端口": "Traffic Port",
|
||||
"浅色": "Clair",
|
||||
"浅色模式": "Mode clair",
|
||||
|
||||
@@ -1141,6 +1141,8 @@
|
||||
"提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址": "ヒント:リンク内の{key}はAPIキーに、{address}はサーバーURLに置換されます",
|
||||
"提示价格:{{symbol}}{{price}} / 1M tokens": "プロンプト料金:{{symbol}}{{price}} / 1M tokens",
|
||||
"提示缓存倍率": "プロンプトキャッシュ倍率",
|
||||
"缓存创建倍率": "キャッシュ作成倍率",
|
||||
"默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "デフォルトは5mのキャッシュ作成倍率です。1hのキャッシュ作成倍率は固定乗数で自動計算されます(現在は1.6倍)",
|
||||
"搜索供应商": "プロバイダーで検索",
|
||||
"搜索关键字": "検索キーワード",
|
||||
"搜索失败": "Search failed",
|
||||
@@ -1543,6 +1545,7 @@
|
||||
"流": "ストリーム",
|
||||
"流式响应完成": "ストリーム完了",
|
||||
"流式输出": "ストリーム出力",
|
||||
"流式": "ストリーミング",
|
||||
"流量端口": "Traffic Port",
|
||||
"浅色": "ライト",
|
||||
"浅色模式": "ライトモード",
|
||||
|
||||
@@ -1167,6 +1167,8 @@
|
||||
"提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址": "Промпт: {key} в ссылке будет заменен на API-ключ, {address} будет заменен на адрес сервера",
|
||||
"提示价格:{{symbol}}{{price}} / 1M tokens": "Цена промпта: {{symbol}}{{price}} / 1M токенов",
|
||||
"提示缓存倍率": "Коэффициент кэша промптов",
|
||||
"缓存创建倍率": "Коэффициент создания кэша",
|
||||
"默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "По умолчанию используется коэффициент создания кэша 5m; коэффициент создания кэша 1h автоматически вычисляется фиксированным умножением (сейчас 1.6x)",
|
||||
"搜索供应商": "Поиск поставщиков",
|
||||
"搜索关键字": "Поиск по ключевым словам",
|
||||
"搜索失败": "Search failed",
|
||||
@@ -1569,6 +1571,7 @@
|
||||
"流": "Поток",
|
||||
"流式响应完成": "Поток завершён",
|
||||
"流式输出": "Потоковый вывод",
|
||||
"流式": "Стриминг",
|
||||
"流量端口": "Traffic Port",
|
||||
"浅色": "Светлая",
|
||||
"浅色模式": "Светлый режим",
|
||||
|
||||
@@ -1142,6 +1142,8 @@
|
||||
"提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址": "Mẹo: {key} trong liên kết sẽ được thay thế bằng khóa API, {address} sẽ được thay thế bằng địa chỉ máy chủ",
|
||||
"提示价格:{{symbol}}{{price}} / 1M tokens": "Giá gợi ý: {{symbol}}{{price}} / 1M tokens",
|
||||
"提示缓存倍率": "Tỷ lệ bộ nhớ đệm gợi ý",
|
||||
"缓存创建倍率": "Tỷ lệ tạo bộ nhớ đệm",
|
||||
"默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "Mặc định dùng tỷ lệ tạo bộ nhớ đệm 5m; tỷ lệ tạo bộ nhớ đệm 1h được tự động tính bằng phép nhân cố định (hiện là 1.6x)",
|
||||
"搜索供应商": "Tìm kiếm nhà cung cấp",
|
||||
"搜索关键字": "Từ khóa tìm kiếm",
|
||||
"搜索失败": "Search failed",
|
||||
@@ -1597,6 +1599,7 @@
|
||||
"流": "luồng",
|
||||
"流式响应完成": "Luồng hoàn tất",
|
||||
"流式输出": "Đầu ra luồng",
|
||||
"流式": "Streaming",
|
||||
"流量端口": "Traffic Port",
|
||||
"浅色": "Sáng",
|
||||
"浅色模式": "Chế độ sáng",
|
||||
|
||||
@@ -1136,6 +1136,8 @@
|
||||
"提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址": "提示:链接中的{key}将被替换为API密钥,{address}将被替换为服务器地址",
|
||||
"提示价格:{{symbol}}{{price}} / 1M tokens": "提示价格:{{symbol}}{{price}} / 1M tokens",
|
||||
"提示缓存倍率": "提示缓存倍率",
|
||||
"缓存创建倍率": "缓存创建倍率",
|
||||
"默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)",
|
||||
"搜索供应商": "搜索供应商",
|
||||
"搜索关键字": "搜索关键字",
|
||||
"搜索失败": "搜索失败",
|
||||
@@ -1538,6 +1540,7 @@
|
||||
"流": "流",
|
||||
"流式响应完成": "流式响应完成",
|
||||
"流式输出": "流式输出",
|
||||
"流式": "流式",
|
||||
"流量端口": "流量端口",
|
||||
"浅色": "浅色",
|
||||
"浅色模式": "浅色模式",
|
||||
|
||||
@@ -43,6 +43,7 @@ export default function ModelRatioSettings(props) {
|
||||
ModelPrice: '',
|
||||
ModelRatio: '',
|
||||
CacheRatio: '',
|
||||
CreateCacheRatio: '',
|
||||
CompletionRatio: '',
|
||||
ImageRatio: '',
|
||||
AudioRatio: '',
|
||||
@@ -200,6 +201,30 @@ export default function ModelRatioSettings(props) {
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row gutter={16}>
|
||||
<Col xs={24} sm={16}>
|
||||
<Form.TextArea
|
||||
label={t('缓存创建倍率')}
|
||||
extraText={t(
|
||||
'默认为 5m 缓存创建倍率;1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)',
|
||||
)}
|
||||
placeholder={t('为一个 JSON 文本,键为模型名称,值为倍率')}
|
||||
field={'CreateCacheRatio'}
|
||||
autosize={{ minRows: 6, maxRows: 12 }}
|
||||
trigger='blur'
|
||||
stopValidateWithError
|
||||
rules={[
|
||||
{
|
||||
validator: (rule, value) => verifyJSON(value),
|
||||
message: '不是合法的 JSON 字符串',
|
||||
},
|
||||
]}
|
||||
onChange={(value) =>
|
||||
setInputs({ ...inputs, CreateCacheRatio: value })
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row gutter={16}>
|
||||
<Col xs={24} sm={16}>
|
||||
<Form.TextArea
|
||||
|
||||
Reference in New Issue
Block a user