Compare commits

...

30 Commits

Author SHA1 Message Date
Seefs
2c0db08f32 Merge pull request #2815 from wans10/main
fix: resolve the issue where the model manager's "sync with official" and "status" toggles could not be saved
2026-02-08 19:56:16 +08:00
Calcium-Ion
11de49f9b9 Merge pull request #2895 from seefs001/fix/model-manager
fix: do not merge the default configuration when the model manager has a custom configuration
2026-02-08 19:53:49 +08:00
Seefs
4950db666f fix: do not merge the default configuration when the model manager has a custom configuration 2026-02-08 19:42:52 +08:00
CaIon
44c5fac5ea refactor(ratio): replace maps with RWMap for improved concurrency handling 2026-02-08 00:48:21 +08:00
Calcium-Ion
7a146a11f5 Merge pull request #2870 from seefs001/feature/cache-creation-configurable
feat: make 5m cache-creation ratio configurable
2026-02-08 00:28:42 +08:00
Calcium-Ion
897955256e Merge pull request #2889 from seefs001/feature/messages2responses
feat: /v1/messages -> /v1/responses
2026-02-08 00:25:27 +08:00
Calcium-Ion
bc6810ca5a Merge pull request #2887 from seefs001/fix/claude
fix: backfill input_tokens and cache-related fields missing from streaming message_delta events
2026-02-08 00:18:04 +08:00
Calcium-Ion
742f4ad1e4 Merge pull request #2883 from seefs001/fix/claude-relay-info-input-token
fix: calls through the OpenAI-compatible API on some channels still used the OpenAI input_token deduction logic even when the final endpoint was the native Claude endpoint
2026-02-08 00:17:50 +08:00
Calcium-Ion
83a5245bb1 Merge pull request #2875 from seefs001/feature/channel-test-stream
feat: channel test with stream=true
2026-02-08 00:17:07 +08:00
Seefs
2faa873caf Merge branch 'feature/messages2responses' into upstream-main
# Conflicts:
#	service/openaicompat/chat_to_responses.go
2026-02-08 00:16:35 +08:00
Calcium-Ion
ce0113a6b5 Merge pull request #2864 from seefs001/fix/thining-summary
fix: add paragraph breaks between reasoning summary chunks
2026-02-08 00:15:32 +08:00
Calcium-Ion
dd5610d39e Merge pull request #2854 from seefs001/fix/claude-tool-index
fix: Claude stream block index/type transitions
2026-02-08 00:15:20 +08:00
Calcium-Ion
8e1a990b45 Merge pull request #2857 from QuantumNous/feat/custom-oauth
feat(oauth): implement custom OAuth provider
2026-02-08 00:13:20 +08:00
Seefs
5f6f95c7c1 Merge pull request #2874 from MUTED64/main
feat: Force beta=true parameter for Anthropic channel
2026-02-08 00:09:28 +08:00
Seefs
0b3a0b38d6 fix: patch message_delta usage via gjson/sjson and skip on passthrough 2026-02-07 19:13:58 +08:00
Thomas
bbad917101 fix: backfill input_tokens and cache-related fields missing from streaming message_delta events (#2881)
When the upstream is AWS Bedrock, the usage in message_delta may be missing fields such as
input_tokens, cache_creation_input_tokens, and cache_read_input_tokens, making it inconsistent
with the native Anthropic format. Backfill these fields from the claudeInfo accumulated at
message_start and re-serialize, so clients receive a consistent usage format.
2026-02-07 18:17:22 +08:00
Seefs
a0bb78edd0 fix: calls through the OpenAI-compatible API on some channels still used the OpenAI input_token deduction logic even when the final endpoint was the native Claude endpoint 2026-02-07 14:21:19 +08:00
Seefs
fac9c367b1 fix: auto default codex to /v1/responses without overriding user-selected endpoint 2026-02-06 22:08:55 +08:00
Seefs
23227e18f9 feat: channel test stream 2026-02-06 21:57:38 +08:00
MUTED64
4332837f05 feat: Force beta=true parameter for Anthropic channel 2026-02-06 21:22:39 +08:00
Seefs
50ee4361d0 feat: make 5m cache-creation ratio configurable 2026-02-06 19:46:59 +08:00
Seefs
3af53bdd41 fix max_output_token 2026-02-06 16:04:49 +08:00
Seefs
aa8240e482 feat: /v1/messages -> /v1/responses 2026-02-06 15:22:32 +08:00
Seefs
b580b8bd1d fix: add paragraph breaks between reasoning summary chunks in chat2responses stream 2026-02-06 14:46:29 +08:00
CaIon
e8d26e52d8 refactor(oauth): update UpdateCustomOAuthProviderRequest to use pointers for optional fields
- Change fields in UpdateCustomOAuthProviderRequest struct to use pointers for optional values, allowing for better handling of nil cases.
- Update UpdateCustomOAuthProvider function to check for nil before assigning optional fields, ensuring existing values are preserved when not provided.
2026-02-05 22:03:30 +08:00
CaIon
2567cff6c8 fix(oauth): enhance error handling and transaction management for OAuth user creation and binding
- Improve error handling in DeleteCustomOAuthProvider to log and return errors when fetching binding counts.
- Refactor user creation and OAuth binding logic to use transactions for atomic operations, ensuring data integrity.
- Add unique constraints to UserOAuthBinding model to prevent duplicate bindings.
- Enhance GitHub OAuth provider error logging for non-200 responses.
- Update AccountManagement component to provide clearer error messages on API failures.
2026-02-05 21:48:05 +08:00
Seefs
7314c974f3 fix: Claude stream block index/type transitions 2026-02-05 19:32:26 +08:00
Seefs
fca80a57ad fix: Claude stream block index/type transitions 2026-02-05 19:11:58 +08:00
wans10
3229b81149 fix(model): prevent zero-value fields from being overwritten by defaults on model create and update
- Save the original status and sync-official values before creating the record
- Use a separate update operation to ensure zero values are correctly persisted to the database
- Change the update method to use Select so all fields, including zero values, are force-updated
- Avoid data loss caused by GORM's default behavior of applying default values to zero-value fields
2026-02-03 13:32:14 +08:00
wans10
5efb402532 refactor(model): streamline the model update logic
- Switch from a global update to a field-map update
- Remove unnecessary session configuration options
- Replace the Omit and Select operations with an explicit field map
- Improve readability and maintainability
- Preserve data consistency while improving performance
2026-02-03 09:48:53 +08:00
47 changed files with 1232 additions and 505 deletions

View File

@@ -2,29 +2,37 @@ package common
import (
"encoding/json"
"sync"
)
var TopupGroupRatio = map[string]float64{
var topupGroupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
var topupGroupRatioMutex sync.RWMutex
func TopupGroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(TopupGroupRatio)
topupGroupRatioMutex.RLock()
defer topupGroupRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(topupGroupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
SysError("error marshalling topup group ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
TopupGroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
topupGroupRatioMutex.Lock()
defer topupGroupRatioMutex.Unlock()
topupGroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &topupGroupRatio)
}
func GetTopupGroupRatio(name string) float64 {
ratio, ok := TopupGroupRatio[name]
topupGroupRatioMutex.RLock()
defer topupGroupRatioMutex.RUnlock()
ratio, ok := topupGroupRatio[name]
if !ok {
SysError("topup group ratio not found: " + name)
return 1
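The hunk above swaps the exported `TopupGroupRatio` map for an unexported map guarded by a `sync.RWMutex`, so concurrent readers no longer race with JSON-driven updates. A minimal standalone sketch of that RWMutex-guarded map pattern (names are illustrative, not the repository's actual identifiers):

```go
package main

import (
	"encoding/json"
	"fmt"
	"sync"
)

// ratioStore guards a ratio map with an RWMutex so concurrent readers do
// not block each other, while writers get exclusive access.
type ratioStore struct {
	mu     sync.RWMutex
	ratios map[string]float64
}

func newRatioStore() *ratioStore {
	return &ratioStore{ratios: map[string]float64{"default": 1}}
}

// Get takes a read lock and falls back to 1 for unknown groups,
// mirroring GetTopupGroupRatio above.
func (s *ratioStore) Get(name string) float64 {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if r, ok := s.ratios[name]; ok {
		return r
	}
	return 1
}

// UpdateFromJSON takes the write lock and replaces the map wholesale,
// mirroring UpdateTopupGroupRatioByJSONString above.
func (s *ratioStore) UpdateFromJSON(jsonStr string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	m := make(map[string]float64)
	if err := json.Unmarshal([]byte(jsonStr), &m); err != nil {
		return err
	}
	s.ratios = m
	return nil
}

func main() {
	s := newRatioStore()
	_ = s.UpdateFromJSON(`{"vip": 2.5}`)
	fmt.Println(s.Get("vip"), s.Get("missing"))
}
```

Unexporting the map is what makes the lock enforceable: callers can no longer bypass the mutex by touching the map directly.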

View File

@@ -31,6 +31,7 @@ import (
"github.com/bytedance/gopkg/util/gopool"
"github.com/samber/lo"
"github.com/tidwall/gjson"
"github.com/gin-gonic/gin"
)
@@ -41,7 +42,21 @@ type testResult struct {
newAPIError *types.NewAPIError
}
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string {
normalized := strings.TrimSpace(endpointType)
if normalized != "" {
return normalized
}
if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) {
return string(constant.EndpointTypeOpenAIResponseCompact)
}
if channel != nil && channel.Type == constant.ChannelTypeCodex {
return string(constant.EndpointTypeOpenAIResponse)
}
return normalized
}
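The resolution order in `normalizeChannelTestEndpoint` is: explicit request parameter first, then a model-name suffix hint, then a channel-type default. A self-contained sketch of that fallback chain, with hypothetical stand-ins for the repository's constants:

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical stand-ins for ratio_setting.CompactModelSuffix and the
// constant.EndpointType* / ChannelTypeCodex values.
const (
	compactSuffix    = "-compact"
	endpointResponse = "openai_response"
	endpointCompact  = "openai_response_compact"
	channelCodex     = 57 // illustrative channel type id
)

type Channel struct{ Type int }

// normalizeEndpoint resolves the effective endpoint for a channel test:
// explicit request > model-suffix hint > channel-type default > empty.
func normalizeEndpoint(ch *Channel, modelName, endpointType string) string {
	if n := strings.TrimSpace(endpointType); n != "" {
		return n
	}
	if strings.HasSuffix(modelName, compactSuffix) {
		return endpointCompact
	}
	if ch != nil && ch.Type == channelCodex {
		return endpointResponse
	}
	return ""
}

func main() {
	fmt.Println(normalizeEndpoint(&Channel{Type: channelCodex}, "codex-mini", ""))
	fmt.Println(normalizeEndpoint(nil, "gpt-x-compact", ""))
	fmt.Println(normalizeEndpoint(nil, "gpt-x", " custom "))
}
```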
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
constant.ChannelTypeMidjourney,
@@ -76,6 +91,8 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)
requestPath := "/v1/chat/completions"
// If an endpoint type was specified, use it
@@ -200,7 +217,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
request := buildTestRequest(testModel, endpointType, channel)
request := buildTestRequest(testModel, endpointType, channel, isStream)
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
@@ -418,16 +435,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: respErr,
}
}
if usageA == nil {
usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
if usageErr != nil {
return testResult{
context: c,
localErr: errors.New("usage is nil"),
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
localErr: usageErr,
newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
usage := usageA.(*dto.Usage)
result := w.Result()
respBody, err := io.ReadAll(result.Body)
respBody, err := readTestResponseBody(result.Body, isStream)
if err != nil {
return testResult{
context: c,
@@ -435,6 +452,13 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
return testResult{
context: c,
localErr: bodyErr,
newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
info.SetEstimatePromptTokens(usage.PromptTokens)
quota := 0
@@ -473,7 +497,101 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) {
switch u := usageAny.(type) {
case *dto.Usage:
return u, nil
case dto.Usage:
return &u, nil
case nil:
if !isStream {
return nil, errors.New("usage is nil")
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
default:
if !isStream {
return nil, fmt.Errorf("invalid usage type: %T", usageAny)
}
usage := &dto.Usage{
PromptTokens: estimatePromptTokens,
}
usage.TotalTokens = usage.PromptTokens
return usage, nil
}
}
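`coerceTestUsage` normalizes the loosely-typed `usage any` value: concrete usage objects pass through, while nil or unexpected types are errors for non-stream tests but tolerated for streams by falling back to the estimated prompt tokens. A dependency-free sketch of the same type-switch coercion (the `Usage` struct here is illustrative, not the repository's `dto.Usage`):

```go
package main

import (
	"errors"
	"fmt"
)

// Usage mirrors the shape of a token-usage DTO (illustrative only).
type Usage struct {
	PromptTokens int
	TotalTokens  int
}

// coerceUsage normalizes the loosely-typed usage value a handler returns.
// For streaming tests, a missing or malformed usage is tolerated by
// synthesizing one from the estimated prompt tokens.
func coerceUsage(v any, isStream bool, estimate int) (*Usage, error) {
	switch u := v.(type) {
	case *Usage:
		return u, nil
	case Usage:
		return &u, nil
	case nil:
		if !isStream {
			return nil, errors.New("usage is nil")
		}
	default:
		if !isStream {
			return nil, fmt.Errorf("invalid usage type: %T", v)
		}
	}
	// Streaming fallback: synthesize usage from the estimate.
	return &Usage{PromptTokens: estimate, TotalTokens: estimate}, nil
}

func main() {
	u, _ := coerceUsage(nil, true, 7)
	fmt.Println(u.PromptTokens)
}
```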
func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) {
defer func() { _ = body.Close() }()
const maxStreamLogBytes = 8 << 10
if isStream {
return io.ReadAll(io.LimitReader(body, maxStreamLogBytes))
}
return io.ReadAll(body)
}
func detectErrorFromTestResponseBody(respBody []byte) error {
b := bytes.TrimSpace(respBody)
if len(b) == 0 {
return nil
}
if message := detectErrorMessageFromJSONBytes(b); message != "" {
return fmt.Errorf("upstream error: %s", message)
}
for _, line := range bytes.Split(b, []byte{'\n'}) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
if message := detectErrorMessageFromJSONBytes(payload); message != "" {
return fmt.Errorf("upstream error: %s", message)
}
}
return nil
}
func detectErrorMessageFromJSONBytes(jsonBytes []byte) string {
if len(jsonBytes) == 0 {
return ""
}
if jsonBytes[0] != '{' && jsonBytes[0] != '[' {
return ""
}
errVal := gjson.GetBytes(jsonBytes, "error")
if !errVal.Exists() || errVal.Type == gjson.Null {
return ""
}
message := gjson.GetBytes(jsonBytes, "error.message").String()
if message == "" {
message = gjson.GetBytes(jsonBytes, "error.error.message").String()
}
if message == "" && errVal.Type == gjson.String {
message = errVal.String()
}
if message == "" {
message = errVal.Raw
}
message = strings.TrimSpace(message)
if message == "" {
return "upstream returned error payload"
}
return message
}
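The two helpers above scan a buffered test-response body for an embedded error object, checking both a plain JSON body and individual SSE `data:` lines while skipping `[DONE]`. A minimal sketch of the same scan, using `encoding/json` instead of gjson to stay dependency-free (so, unlike the original, it only recognizes an `{"error":{"message":...}}` shape):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// detectSSEError scans a buffered body for an embedded error payload,
// first as a plain JSON document, then line by line as SSE "data:" events.
func detectSSEError(body []byte) string {
	b := bytes.TrimSpace(body)
	if len(b) == 0 {
		return ""
	}
	if msg := errorMessage(b); msg != "" {
		return msg
	}
	for _, line := range bytes.Split(b, []byte{'\n'}) {
		line = bytes.TrimSpace(line)
		if !bytes.HasPrefix(line, []byte("data:")) {
			continue
		}
		payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
		if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
			continue
		}
		if msg := errorMessage(payload); msg != "" {
			return msg
		}
	}
	return ""
}

// errorMessage extracts error.message from a JSON payload, if present.
func errorMessage(p []byte) string {
	var doc struct {
		Error *struct {
			Message string `json:"message"`
		} `json:"error"`
	}
	if err := json.Unmarshal(p, &doc); err != nil || doc.Error == nil {
		return ""
	}
	if doc.Error.Message != "" {
		return doc.Error.Message
	}
	return "upstream returned error payload"
}

func main() {
	stream := "data: {\"error\":{\"message\":\"rate limited\"}}\n\ndata: [DONE]\n"
	fmt.Println(detectSSEError([]byte(stream)))
}
```

This is why a channel test can fail even on an HTTP 200: many upstreams tunnel errors inside an otherwise successful SSE stream.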
func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request {
testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
// Build a different test request depending on the endpoint type
@@ -504,8 +622,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
case constant.EndpointTypeOpenAIResponse:
// Return an OpenAIResponsesRequest
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: isStream,
}
case constant.EndpointTypeOpenAIResponseCompact:
// Return an OpenAIResponsesCompactionRequest
@@ -519,9 +638,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
maxTokens = 3000
}
return &dto.GeneralOpenAIRequest{
req := &dto.GeneralOpenAIRequest{
Model: model,
Stream: false,
Stream: isStream,
Messages: []dto.Message{
{
Role: "user",
@@ -530,6 +649,10 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
},
MaxTokens: maxTokens,
}
if isStream {
req.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
return req
}
}
@@ -565,15 +688,16 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
// Responses-only models (e.g. codex series)
if strings.Contains(strings.ToLower(model), "codex") {
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
Stream: isStream,
}
}
// Chat/Completion request - return a GeneralOpenAIRequest
testRequest := &dto.GeneralOpenAIRequest{
Model: model,
Stream: false,
Stream: isStream,
Messages: []dto.Message{
{
Role: "user",
@@ -581,6 +705,9 @@ func buildTestRequest(model string, endpointType string, channel *model.Channel)
},
},
}
if isStream {
testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
}
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 16
@@ -618,8 +745,9 @@ func TestChannel(c *gin.Context) {
//}()
testModel := c.Query("model")
endpointType := c.Query("endpoint_type")
isStream, _ := strconv.ParseBool(c.Query("stream"))
tik := time.Now()
result := testChannel(channel, testModel, endpointType)
result := testChannel(channel, testModel, endpointType, isStream)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -678,7 +806,7 @@ func testAllChannels(notify bool) error {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
result := testChannel(channel, "", "")
result := testChannel(channel, "", "", false)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()

View File

@@ -166,21 +166,21 @@ func CreateCustomOAuthProvider(c *gin.Context) {
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
type UpdateCustomOAuthProviderRequest struct {
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
Name string `json:"name"`
Slug string `json:"slug"`
Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
@@ -227,7 +227,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.Slug != "" {
provider.Slug = req.Slug
}
provider.Enabled = req.Enabled
if req.Enabled != nil {
provider.Enabled = *req.Enabled
}
if req.ClientId != "" {
provider.ClientId = req.ClientId
}
@@ -258,8 +260,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.EmailField != "" {
provider.EmailField = req.EmailField
}
provider.WellKnown = req.WellKnown
provider.AuthStyle = req.AuthStyle
if req.WellKnown != nil {
provider.WellKnown = *req.WellKnown
}
if req.AuthStyle != nil {
provider.AuthStyle = *req.AuthStyle
}
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
@@ -296,7 +302,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
}
// Check if there are any user bindings
count, _ := model.GetBindingCountByProviderId(id)
count, err := model.GetBindingCountByProviderId(id)
if err != nil {
common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
return
}
if count > 0 {
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
return

View File

@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// providerParams returns map with Provider key for i18n templates
@@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
return nil, err
}
// For custom providers, create the binding after user is created
// Use transaction to ensure user creation and OAuth binding are atomic
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBinding(binding); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
// Don't fail the registration, just log the error
// Custom provider: create user and binding in a transaction
err := model.DB.Transaction(func(tx *gorm.DB) error {
// Create user
if err := user.InsertWithTx(tx, inviterId); err != nil {
return err
}
// Create OAuth binding
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// Perform post-transaction tasks (logs, sidebar config, inviter rewards)
user.FinalizeOAuthUserCreation(inviterId)
} else {
// Built-in provider: set the provider user ID on the user model
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := user.Update(false); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
// Built-in provider: create user and update provider ID in a transaction
err := model.DB.Transaction(func(tx *gorm.DB) error {
// Create user
if err := user.InsertWithTx(tx, inviterId); err != nil {
return err
}
// Set the provider user ID on the user model and update
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := tx.Model(user).Updates(map[string]interface{}{
"github_id": user.GitHubId,
"discord_id": user.DiscordId,
"oidc_id": user.OidcId,
"linux_do_id": user.LinuxDOId,
"wechat_id": user.WeChatId,
"telegram_id": user.TelegramId,
}).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// Perform post-transaction tasks
user.FinalizeOAuthUserCreation(inviterId)
}
return user, nil

View File

@@ -169,6 +169,15 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "CreateCacheRatio":
err = ratio_setting.UpdateCreateCacheRatioByJSONString(option.Value.(string))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "缓存创建倍率设置失败: " + err.Error(),
})
return
}
case "ModelRequestRateLimitGroup":
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
if err != nil {

View File

@@ -27,6 +27,7 @@ type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)

View File

@@ -97,13 +97,18 @@ func DeleteCustomOAuthProvider(id int) error {
}
// IsSlugTaken checks if a slug is already taken by another provider
// Returns true on DB errors (fail-closed) to prevent slug conflicts
func IsSlugTaken(slug string, excludeId int) bool {
var count int64
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
if excludeId > 0 {
query = query.Where("id != ?", excludeId)
}
query.Count(&count)
res := query.Count(&count)
if res.Error != nil {
// Fail-closed: treat DB errors as slug being taken to prevent conflicts
return true
}
return count > 0
}

View File

@@ -47,7 +47,21 @@ func (mi *Model) Insert() error {
now := common.GetTimestamp()
mi.CreatedTime = now
mi.UpdatedTime = now
return DB.Create(mi).Error
// Save the original values (after Create they may be overwritten to 1 by GORM's default tag)
originalStatus := mi.Status
originalSyncOfficial := mi.SyncOfficial
// Create the record first (GORM applies default values to zero-value fields)
if err := DB.Create(mi).Error; err != nil {
return err
}
// Update with the saved original values to ensure zero values are persisted correctly
return DB.Model(&Model{}).Where("id = ?", mi.Id).Updates(map[string]interface{}{
"status": originalStatus,
"sync_official": originalSyncOfficial,
}).Error
}
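The root cause of this fix: when an ORM diffs a struct to build a SET clause, a `false` or `0` field is indistinguishable from "not set", so zero values get silently dropped (or replaced by `default` tags). The sketch below emulates that struct-diffing behavior with reflection to show why the fix switches to an explicit field map (illustrative, not gorm itself):

```go
package main

import (
	"fmt"
	"reflect"
)

type ModelRow struct {
	Status       int
	SyncOfficial bool
}

// structUpdates emulates struct-based ORM Updates: zero-value fields are
// silently dropped from the generated SET clause.
func structUpdates(row ModelRow) map[string]any {
	out := map[string]any{}
	v := reflect.ValueOf(row)
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		if !v.Field(i).IsZero() {
			out[t.Field(i).Name] = v.Field(i).Interface()
		}
	}
	return out
}

func main() {
	row := ModelRow{Status: 0, SyncOfficial: false}
	fmt.Println(len(structUpdates(row))) // zero-value fields vanish

	// An explicit map keeps the zero values, matching the fix in Insert().
	explicit := map[string]any{"status": row.Status, "sync_official": row.SyncOfficial}
	fmt.Println(len(explicit))
}
```

The `Update()` variant solves the same problem differently: `Select(...)` names every column to update, forcing GORM to emit zero values too.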
func IsModelNameDuplicated(id int, name string) (bool, error) {
@@ -61,11 +75,9 @@ func IsModelNameDuplicated(id int, name string) (bool, error) {
func (mi *Model) Update() error {
mi.UpdatedTime = common.GetTimestamp()
return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
Model(&Model{}).
Where("id = ?", mi.Id).
Omit("created_time").
Select("*").
// Use Select to force-update all listed fields, including zero values
return DB.Model(&Model{}).Where("id = ?", mi.Id).
Select("model_name", "description", "icon", "tags", "vendor_id", "endpoints", "status", "sync_official", "name_rule", "updated_time").
Updates(mi).Error
}

View File

@@ -115,6 +115,7 @@ func InitOptionMap() {
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
common.OptionMap["CreateCacheRatio"] = ratio_setting.CreateCacheRatio2JSONString()
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
@@ -427,6 +428,8 @@ func updateOptionMap(key string, value string) (err error) {
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "CreateCacheRatio":
err = ratio_setting.UpdateCreateCacheRatioByJSONString(value)
case "ImageRatio":
err = ratio_setting.UpdateImageRatioByJSONString(value)
case "AudioRatio":

View File

@@ -196,20 +196,25 @@ func updatePricing() {
modelSupportEndpointsStr[ability.Model] = endpoints
}
// Then append each model's custom endpoints
// Then apply each model's custom endpoints: a valid config replaces the default endpoints rather than merging with them
for modelName, meta := range metaMap {
if strings.TrimSpace(meta.Endpoints) == "" {
continue
}
var raw map[string]interface{}
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
endpoints := modelSupportEndpointsStr[modelName]
for k := range raw {
if !common.StringsContains(endpoints, k) {
endpoints = append(endpoints, k)
endpoints := make([]string, 0, len(raw))
for k, v := range raw {
switch v.(type) {
case string, map[string]interface{}:
if !common.StringsContains(endpoints, k) {
endpoints = append(endpoints, k)
}
}
}
modelSupportEndpointsStr[modelName] = endpoints
if len(endpoints) > 0 {
modelSupportEndpointsStr[modelName] = endpoints
}
}
}
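The new loop only accepts endpoint entries whose values are a string or an object, and only replaces the channel-derived defaults when at least one valid key survives. A standalone sketch of that validate-then-replace step (`customEndpoints` is a hypothetical helper, not the repository's function):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// customEndpoints parses a model's endpoint-override JSON and returns the
// endpoint keys whose values have a recognized shape (string or object).
// A non-empty valid result replaces the defaults; invalid or empty config
// leaves the defaults untouched.
func customEndpoints(raw string, defaults []string) []string {
	var parsed map[string]any
	if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
		return defaults
	}
	endpoints := make([]string, 0, len(parsed))
	for k, v := range parsed {
		switch v.(type) {
		case string, map[string]any:
			endpoints = append(endpoints, k)
		}
	}
	if len(endpoints) == 0 {
		return defaults
	}
	return endpoints
}

func main() {
	defaults := []string{"/v1/chat/completions"}
	fmt.Println(customEndpoints(`{"/v1/responses": {}}`, defaults))
	fmt.Println(customEndpoints(`{"bad": 1}`, defaults))
}
```

Replacing instead of merging is the behavioral change from PR #2895: a custom configuration now fully overrides the defaults instead of being appended to them.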

View File

@@ -429,6 +429,65 @@ func (user *User) Insert(inviterId int) error {
return nil
}
// InsertWithTx inserts a new user within an existing transaction.
// This is used for OAuth registration where user creation and binding need to be atomic.
// Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
var err error
if user.Password != "" {
user.Password, err = common.Password2Hash(user.Password)
if err != nil {
return err
}
}
user.Quota = common.QuotaForNewUser
user.AffCode = common.GetRandomString(4)
// Initialize user settings
if user.Setting == "" {
defaultSetting := dto.UserSetting{}
user.SetSetting(defaultSetting)
}
result := tx.Create(user)
if result.Error != nil {
return result.Error
}
return nil
}
// FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
// This should be called after the transaction commits successfully.
func (user *User) FinalizeOAuthUserCreation(inviterId int) {
// After the user is created, initialize the sidebar config based on role
var createdUser User
if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
if defaultSidebarConfig != "" {
currentSetting := createdUser.GetSetting()
currentSetting.SidebarModules = defaultSidebarConfig
createdUser.SetSetting(currentSetting)
createdUser.Update(false)
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
}
}
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
_ = inviteUser(inviterId)
}
}
}
func (user *User) Update(updatePassword bool) error {
var err error
if updatePassword {

View File

@@ -3,18 +3,17 @@ package model
import (
"errors"
"time"
"gorm.io/gorm"
)
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
type UserOAuthBinding struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"index;not null"` // User ID
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
CreatedAt time.Time `json:"created_at"`
// Composite unique index to prevent duplicate bindings
// One OAuth account can only be bound to one user
}
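The gorm tags above declare two composite unique indexes: `ux_user_provider` (one binding per user per provider) and `ux_provider_userid` (one user per OAuth account per provider). A toy in-memory sketch of the invariants those indexes enforce at the database level (illustrative; the real guarantee comes from the DB, which also closes the check-then-insert race that an application-level lookup alone cannot):

```go
package main

import (
	"errors"
	"fmt"
)

// bindingStore models the two composite unique indexes on UserOAuthBinding.
type bindingStore struct {
	byUserProvider   map[[2]int]bool // (user_id, provider_id)
	byProviderUserID map[string]bool // (provider_id, provider_user_id)
}

func newBindingStore() *bindingStore {
	return &bindingStore{
		byUserProvider:   map[[2]int]bool{},
		byProviderUserID: map[string]bool{},
	}
}

// Bind enforces both constraints the way the DB indexes do.
func (s *bindingStore) Bind(userID, providerID int, providerUserID string) error {
	upKey := [2]int{userID, providerID}
	puKey := fmt.Sprintf("%d/%s", providerID, providerUserID)
	if s.byUserProvider[upKey] {
		return errors.New("user already bound to this provider")
	}
	if s.byProviderUserID[puKey] {
		return errors.New("this OAuth account is already bound to another user")
	}
	s.byUserProvider[upKey] = true
	s.byProviderUserID[puKey] = true
	return nil
}

func main() {
	s := newBindingStore()
	fmt.Println(s.Bind(1, 10, "gh-42"))
	fmt.Println(s.Bind(2, 10, "gh-42")) // same OAuth account, different user
}
```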
func (UserOAuthBinding) TableName() string {
@@ -82,6 +81,29 @@ func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
return DB.Create(binding).Error
}
// CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction
func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
if binding.UserId == 0 {
return errors.New("user ID is required")
}
if binding.ProviderId == 0 {
return errors.New("provider ID is required")
}
if binding.ProviderUserId == "" {
return errors.New("provider user ID is required")
}
// Check if this provider user ID is already taken (use tx to check within the same transaction)
var count int64
tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count)
if count > 0 {
return errors.New("this OAuth account is already bound to another user")
}
binding.CreatedAt = time.Now()
return tx.Create(binding).Error
}
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
// Check if the new provider user ID is already taken by another user

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"
@@ -122,6 +123,17 @@ func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*O
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
// Check for non-200 status codes before attempting to decode
if res.StatusCode != http.StatusOK {
body, _ := io.ReadAll(res.Body)
bodyStr := string(body)
if len(bodyStr) > 500 {
bodyStr = bodyStr[:500] + "..."
}
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo failed: status=%d, body=%s", res.StatusCode, bodyStr))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, map[string]any{"Provider": "GitHub"}, fmt.Sprintf("status %d", res.StatusCode))
}
var githubUser gitHubUser
err = json.NewDecoder(res.Body).Decode(&githubUser)
if err != nil {

View File

@@ -223,11 +223,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if supportsAliAnthropicMessages(info.UpstreamModelName) {
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
}
return claude.ClaudeHandler(c, resp, info)
adaptor := claude.Adaptor{}
return adaptor.DoResponse(c, resp, info)
}
adaptor := openai.Adaptor{}

View File

@@ -95,6 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
info.FinalRequestRelayFormat = types.RelayFormatClaude
if info.IsStream {
return ClaudeStreamHandler(c, resp, info)
} else {

View File

@@ -0,0 +1,111 @@
package claude
import (
"testing"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestPatchClaudeMessageDeltaUsageDataPreserveUnknownFields(t *testing.T) {
originalData := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":53},"vendor_meta":{"trace_id":"trace_001"}}`
usage := &dto.ClaudeUsage{
InputTokens: 100,
CacheReadInputTokens: 30,
CacheCreationInputTokens: 50,
}
patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
require.Equal(t, "message_delta", gjson.Get(patchedData, "type").String())
require.Equal(t, "end_turn", gjson.Get(patchedData, "delta.stop_reason").String())
require.Equal(t, "trace_001", gjson.Get(patchedData, "vendor_meta.trace_id").String())
require.EqualValues(t, 53, gjson.Get(patchedData, "usage.output_tokens").Int())
require.EqualValues(t, 100, gjson.Get(patchedData, "usage.input_tokens").Int())
require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
require.EqualValues(t, 50, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Int())
}
func TestPatchClaudeMessageDeltaUsageDataZeroValueChecks(t *testing.T) {
originalData := `{"type":"message_delta","usage":{"output_tokens":53,"input_tokens":9,"cache_read_input_tokens":0}}`
usage := &dto.ClaudeUsage{
InputTokens: 100,
CacheReadInputTokens: 30,
CacheCreationInputTokens: 0,
}
patchedData := patchClaudeMessageDeltaUsageData(originalData, usage)
require.EqualValues(t, 9, gjson.Get(patchedData, "usage.input_tokens").Int())
require.EqualValues(t, 30, gjson.Get(patchedData, "usage.cache_read_input_tokens").Int())
assert.False(t, gjson.Get(patchedData, "usage.cache_creation_input_tokens").Exists())
}
func TestShouldSkipClaudeMessageDeltaUsagePatch(t *testing.T) {
originGlobalPassThrough := model_setting.GetGlobalSettings().PassThroughRequestEnabled
t.Cleanup(func() {
model_setting.GetGlobalSettings().PassThroughRequestEnabled = originGlobalPassThrough
})
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{}))
model_setting.GetGlobalSettings().PassThroughRequestEnabled = false
assert.True(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: true}},
}))
assert.False(t, shouldSkipClaudeMessageDeltaUsagePatch(&relaycommon.RelayInfo{
ChannelMeta: &relaycommon.ChannelMeta{ChannelSetting: dto.ChannelSettings{PassThroughBodyEnabled: false}},
}))
}
func TestBuildMessageDeltaPatchUsage(t *testing.T) {
t.Run("merge missing fields from claudeInfo", func(t *testing.T) {
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{OutputTokens: 53}}
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{
PromptTokens: 100,
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 30,
CachedCreationTokens: 50,
},
ClaudeCacheCreation5mTokens: 10,
ClaudeCacheCreation1hTokens: 20,
},
}
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
require.NotNil(t, usage)
require.EqualValues(t, 100, usage.InputTokens)
require.EqualValues(t, 30, usage.CacheReadInputTokens)
require.EqualValues(t, 50, usage.CacheCreationInputTokens)
require.EqualValues(t, 53, usage.OutputTokens)
require.NotNil(t, usage.CacheCreation)
require.EqualValues(t, 10, usage.CacheCreation.Ephemeral5mInputTokens)
require.EqualValues(t, 20, usage.CacheCreation.Ephemeral1hInputTokens)
})
t.Run("keep upstream non-zero values", func(t *testing.T) {
claudeResponse := &dto.ClaudeResponse{Usage: &dto.ClaudeUsage{
InputTokens: 9,
CacheReadInputTokens: 7,
CacheCreationInputTokens: 6,
}}
claudeInfo := &ClaudeResponseInfo{Usage: &dto.Usage{
PromptTokens: 100,
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 30,
CachedCreationTokens: 50,
},
}}
usage := buildMessageDeltaPatchUsage(claudeResponse, claudeInfo)
require.EqualValues(t, 9, usage.InputTokens)
require.EqualValues(t, 7, usage.CacheReadInputTokens)
require.EqualValues(t, 6, usage.CacheCreationInputTokens)
})
}

View File

@@ -21,6 +21,8 @@ import (
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -544,6 +546,78 @@ type ClaudeResponseInfo struct {
Done bool
}
func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage {
usage := &dto.ClaudeUsage{}
if claudeResponse != nil && claudeResponse.Usage != nil {
*usage = *claudeResponse.Usage
}
if claudeInfo == nil || claudeInfo.Usage == nil {
return usage
}
if usage.InputTokens == 0 && claudeInfo.Usage.PromptTokens > 0 {
usage.InputTokens = claudeInfo.Usage.PromptTokens
}
if usage.CacheReadInputTokens == 0 && claudeInfo.Usage.PromptTokensDetails.CachedTokens > 0 {
usage.CacheReadInputTokens = claudeInfo.Usage.PromptTokensDetails.CachedTokens
}
if usage.CacheCreationInputTokens == 0 && claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens > 0 {
usage.CacheCreationInputTokens = claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens
}
if usage.CacheCreation == nil && (claudeInfo.Usage.ClaudeCacheCreation5mTokens > 0 || claudeInfo.Usage.ClaudeCacheCreation1hTokens > 0) {
usage.CacheCreation = &dto.ClaudeCacheCreationUsage{
Ephemeral5mInputTokens: claudeInfo.Usage.ClaudeCacheCreation5mTokens,
Ephemeral1hInputTokens: claudeInfo.Usage.ClaudeCacheCreation1hTokens,
}
}
return usage
}
func shouldSkipClaudeMessageDeltaUsagePatch(info *relaycommon.RelayInfo) bool {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
return true
}
if info == nil {
return false
}
return info.ChannelSetting.PassThroughBodyEnabled
}
func patchClaudeMessageDeltaUsageData(data string, usage *dto.ClaudeUsage) string {
if data == "" || usage == nil {
return data
}
data = setMessageDeltaUsageInt(data, "usage.input_tokens", usage.InputTokens)
data = setMessageDeltaUsageInt(data, "usage.cache_read_input_tokens", usage.CacheReadInputTokens)
data = setMessageDeltaUsageInt(data, "usage.cache_creation_input_tokens", usage.CacheCreationInputTokens)
if usage.CacheCreation != nil {
data = setMessageDeltaUsageInt(data, "usage.cache_creation.ephemeral_5m_input_tokens", usage.CacheCreation.Ephemeral5mInputTokens)
data = setMessageDeltaUsageInt(data, "usage.cache_creation.ephemeral_1h_input_tokens", usage.CacheCreation.Ephemeral1hInputTokens)
}
return data
}
func setMessageDeltaUsageInt(data string, path string, localValue int) string {
if localValue <= 0 {
return data
}
upstreamValue := gjson.Get(data, path)
if upstreamValue.Exists() && upstreamValue.Int() > 0 {
return data
}
patchedData, err := sjson.Set(data, path, localValue)
if err != nil {
return data
}
return patchedData
}
func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if claudeInfo == nil {
return false
@@ -638,6 +712,12 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
if claudeResponse.Message != nil {
info.UpstreamModelName = claudeResponse.Message.Model
}
} else if claudeResponse.Type == "message_delta" {
// Ensure the message_delta usage carries complete input_tokens and cache-related fields.
// Fixes upstreams such as AWS Bedrock whose message_delta events omit these fields.
if !shouldSkipClaudeMessageDeltaUsagePatch(info) {
data = patchClaudeMessageDeltaUsageData(data, buildMessageDeltaPatchUsage(&claudeResponse, claudeInfo))
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == types.RelayFormatOpenAI {

View File

@@ -0,0 +1,175 @@
package claude
import (
"strings"
"testing"
"github.com/QuantumNous/new-api/dto"
)
func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) {
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{},
}
claudeResponse := &dto.ClaudeResponse{
Type: "message_start",
Message: &dto.ClaudeMediaMessage{
Id: "msg_123",
Model: "claude-3-5-sonnet",
Usage: &dto.ClaudeUsage{
InputTokens: 100,
OutputTokens: 1,
CacheCreationInputTokens: 50,
CacheReadInputTokens: 30,
},
},
}
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
if !ok {
t.Fatal("expected true")
}
if claudeInfo.Usage.PromptTokens != 100 {
t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens)
}
if claudeInfo.Usage.PromptTokensDetails.CachedTokens != 30 {
t.Errorf("CachedTokens = %d, want 30", claudeInfo.Usage.PromptTokensDetails.CachedTokens)
}
if claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens != 50 {
t.Errorf("CachedCreationTokens = %d, want 50", claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens)
}
if claudeInfo.ResponseId != "msg_123" {
t.Errorf("ResponseId = %s, want msg_123", claudeInfo.ResponseId)
}
if claudeInfo.Model != "claude-3-5-sonnet" {
t.Errorf("Model = %s, want claude-3-5-sonnet", claudeInfo.Model)
}
}
func TestFormatClaudeResponseInfo_MessageDelta_FullUsage(t *testing.T) {
// message_start accumulates usage first
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{
PromptTokens: 100,
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 30,
CachedCreationTokens: 50,
},
CompletionTokens: 1,
},
}
// message_delta carries full usage (native Anthropic scenario)
claudeResponse := &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: 100,
OutputTokens: 200,
CacheCreationInputTokens: 50,
CacheReadInputTokens: 30,
},
}
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
if !ok {
t.Fatal("expected true")
}
if claudeInfo.Usage.PromptTokens != 100 {
t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens)
}
if claudeInfo.Usage.CompletionTokens != 200 {
t.Errorf("CompletionTokens = %d, want 200", claudeInfo.Usage.CompletionTokens)
}
if claudeInfo.Usage.TotalTokens != 300 {
t.Errorf("TotalTokens = %d, want 300", claudeInfo.Usage.TotalTokens)
}
if !claudeInfo.Done {
t.Error("expected Done = true")
}
}
func TestFormatClaudeResponseInfo_MessageDelta_OnlyOutputTokens(t *testing.T) {
// Simulate Bedrock: message_start has already accumulated usage
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{
PromptTokens: 100,
PromptTokensDetails: dto.InputTokenDetails{
CachedTokens: 30,
CachedCreationTokens: 50,
},
CompletionTokens: 1,
ClaudeCacheCreation5mTokens: 10,
ClaudeCacheCreation1hTokens: 20,
},
}
// Bedrock's message_delta only carries output_tokens, missing input_tokens and the cache fields
claudeResponse := &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
OutputTokens: 200,
// InputTokens, CacheCreationInputTokens, and CacheReadInputTokens are all 0
},
}
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
if !ok {
t.Fatal("expected true")
}
// PromptTokens should keep the message_start value (message_delta's InputTokens is 0, so no update)
if claudeInfo.Usage.PromptTokens != 100 {
t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens)
}
if claudeInfo.Usage.CompletionTokens != 200 {
t.Errorf("CompletionTokens = %d, want 200", claudeInfo.Usage.CompletionTokens)
}
if claudeInfo.Usage.TotalTokens != 300 {
t.Errorf("TotalTokens = %d, want 300", claudeInfo.Usage.TotalTokens)
}
// cache fields should keep the message_start values
if claudeInfo.Usage.PromptTokensDetails.CachedTokens != 30 {
t.Errorf("CachedTokens = %d, want 30", claudeInfo.Usage.PromptTokensDetails.CachedTokens)
}
if claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens != 50 {
t.Errorf("CachedCreationTokens = %d, want 50", claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens)
}
if claudeInfo.Usage.ClaudeCacheCreation5mTokens != 10 {
t.Errorf("ClaudeCacheCreation5mTokens = %d, want 10", claudeInfo.Usage.ClaudeCacheCreation5mTokens)
}
if claudeInfo.Usage.ClaudeCacheCreation1hTokens != 20 {
t.Errorf("ClaudeCacheCreation1hTokens = %d, want 20", claudeInfo.Usage.ClaudeCacheCreation1hTokens)
}
if !claudeInfo.Done {
t.Error("expected Done = true")
}
}
func TestFormatClaudeResponseInfo_NilClaudeInfo(t *testing.T) {
claudeResponse := &dto.ClaudeResponse{Type: "message_start"}
ok := FormatClaudeResponseInfo(claudeResponse, nil, nil)
if ok {
t.Error("expected false for nil claudeInfo")
}
}
func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) {
text := "hello"
claudeInfo := &ClaudeResponseInfo{
Usage: &dto.Usage{},
ResponseText: strings.Builder{},
}
claudeResponse := &dto.ClaudeResponse{
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Text: &text,
},
}
ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo)
if !ok {
t.Fatal("expected true")
}
if claudeInfo.ResponseText.String() != "hello" {
t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello")
}
}

View File

@@ -26,7 +26,7 @@ func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayIn
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
return nil, errors.New("codex channel: endpoint not supported")
return nil, errors.New("codex channel: /v1/messages endpoint not supported")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -41,15 +41,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("codex channel: endpoint not supported")
return nil, errors.New("codex channel: /v1/chat/completions endpoint not supported")
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("codex channel: endpoint not supported")
return nil, errors.New("codex channel: /v1/rerank endpoint not supported")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("codex channel: endpoint not supported")
return nil, errors.New("codex channel: /v1/embeddings endpoint not supported")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {

View File

@@ -95,11 +95,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
} else {
return claude.ClaudeHandler(c, resp, info)
}
adaptor := claude.Adaptor{}
return adaptor.DoResponse(c, resp, info)
default:
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)

View File

@@ -102,11 +102,8 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
} else {
return claude.ClaudeHandler(c, resp, info)
}
adaptor := claude.Adaptor{}
return adaptor.DoResponse(c, resp, info)
default:
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)

View File

@@ -171,7 +171,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
default:
if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
if (info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini) &&
info.RelayMode != relayconstant.RelayModeResponses &&
info.RelayMode != relayconstant.RelayModeResponsesCompact {
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil

View File

@@ -71,12 +71,22 @@ func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
chatResp.Usage = *usage
}
chatBody, err := common.Marshal(chatResp)
var responseBody []byte
switch info.RelayFormat {
case types.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(chatResp, info)
responseBody, err = common.Marshal(claudeResp)
case types.RelayFormatGemini:
geminiResp := service.ResponseOpenAI2Gemini(chatResp, info)
responseBody, err = common.Marshal(geminiResp)
default:
responseBody, err = common.Marshal(chatResp)
}
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
}
service.IOCopyBytesGracefully(c, resp, chatBody)
service.IOCopyBytesGracefully(c, resp, responseBody)
return usage, nil
}
@@ -106,14 +116,43 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
toolCallArgsByID := make(map[string]string)
toolCallNameSent := make(map[string]bool)
toolCallCanonicalIDByItemID := make(map[string]string)
hasSentReasoningSummary := false
needsReasoningSummarySeparator := false
//reasoningSummaryTextByKey := make(map[string]string)
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo == nil {
info.ClaudeConvertInfo = &relaycommon.ClaudeConvertInfo{LastMessagesType: relaycommon.LastMessageTypeNone}
}
sendChatChunk := func(chunk *dto.ChatCompletionsStreamResponse) bool {
if chunk == nil {
return true
}
if info.RelayFormat == types.RelayFormatOpenAI {
if err := helper.ObjectData(c, chunk); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
return false
}
return true
}
chunkData, err := common.Marshal(chunk)
if err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
return false
}
if err := HandleStreamFormat(c, info, string(chunkData), false, false); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
return false
}
return true
}
sendStartIfNeeded := func() bool {
if sentStart {
return true
}
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)) {
return false
}
sentStart = true
@@ -154,6 +193,17 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
if delta == "" {
return true
}
if needsReasoningSummarySeparator {
if strings.HasPrefix(delta, "\n\n") {
needsReasoningSummarySeparator = false
} else if strings.HasPrefix(delta, "\n") {
delta = "\n" + delta
needsReasoningSummarySeparator = false
} else {
delta = "\n\n" + delta
needsReasoningSummarySeparator = false
}
}
if !sendStartIfNeeded() {
return false
}
@@ -173,10 +223,10 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
},
},
}
if err := helper.ObjectData(c, chunk); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(chunk) {
return false
}
hasSentReasoningSummary = true
return true
}
@@ -231,8 +281,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
},
},
}
if err := helper.ObjectData(c, chunk); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(chunk) {
return false
}
sawToolCall = true
@@ -282,6 +331,9 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
case "response.reasoning_summary_text.done":
if hasSentReasoningSummary {
needsReasoningSummarySeparator = true
}
//case "response.reasoning_summary_part.added", "response.reasoning_summary_part.done":
// key := responsesStreamIndexKey(strings.TrimSpace(streamResp.ItemID), streamResp.SummaryIndex)
@@ -323,8 +375,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
},
},
}
if err := helper.ObjectData(c, chunk); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(chunk) {
return false
}
}
@@ -419,13 +470,15 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
return false
}
if !sentStop {
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
info.ClaudeConvertInfo.Usage = usage
}
finishReason := "stop"
if sawToolCall && outputText.Len() == 0 {
finishReason = "tool_calls"
}
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
if err := helper.ObjectData(c, stop); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(stop) {
return false
}
sentStop = true
@@ -456,26 +509,31 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
if !sentStart {
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)) {
return nil, streamErr
}
}
if !sentStop {
if info.RelayFormat == types.RelayFormatClaude && info.ClaudeConvertInfo != nil {
info.ClaudeConvertInfo.Usage = usage
}
finishReason := "stop"
if sawToolCall && outputText.Len() == 0 {
finishReason = "tool_calls"
}
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
if err := helper.ObjectData(c, stop); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
if !sendChatChunk(stop) {
return nil, streamErr
}
}
if info.ShouldIncludeUsage && usage != nil {
if info.RelayFormat == types.RelayFormatOpenAI && info.ShouldIncludeUsage && usage != nil {
if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
}
helper.Done(c)
if info.RelayFormat == types.RelayFormatOpenAI {
helper.Done(c)
}
return usage, nil
}

View File

@@ -365,10 +365,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
claudeAdaptor := claude.Adaptor{}
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
return claude.ClaudeStreamHandler(c, resp, info)
return claudeAdaptor.DoResponse(c, resp, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
@@ -381,7 +382,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
switch a.RequestMode {
case RequestModeClaude:
return claude.ClaudeHandler(c, resp, info)
return claudeAdaptor.DoResponse(c, resp, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationHandler(c, info, resp)

View File

@@ -347,10 +347,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayFormat == types.RelayFormatClaude {
if _, ok := channelconstant.ChannelSpecialBases[info.ChannelBaseUrl]; ok {
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
}
return claude.ClaudeHandler(c, resp, info)
adaptor := claude.Adaptor{}
return adaptor.DoResponse(c, resp, info)
}
}

View File

@@ -109,11 +109,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
} else {
return claude.ClaudeHandler(c, resp, info)
}
adaptor := claude.Adaptor{}
return adaptor.DoResponse(c, resp, info)
default:
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
return zhipu4vImageHandler(c, resp, info)

View File

@@ -110,6 +110,23 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
}
}
if !model_setting.GetGlobalSettings().PassThroughRequestEnabled &&
!info.ChannelSetting.PassThroughBodyEnabled &&
service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.ChannelType, info.OriginModelName) {
openAIRequest, convErr := service.ClaudeToOpenAIRequest(*request, info)
if convErr != nil {
return types.NewError(convErr, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, openAIRequest)
if newApiErr != nil {
return newApiErr
}
service.PostClaudeConsumeQuota(c, info, usage)
return nil
}
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)

View File

@@ -37,6 +37,9 @@ type ClaudeConvertInfo struct {
Usage *dto.Usage
FinishReason string
Done bool
ToolCallBaseIndex int
ToolCallMaxIndexOffset int
}
type RerankerInfo struct {
@@ -145,6 +148,8 @@ type RelayInfo struct {
// RequestConversionChain records request format conversions in order, e.g.
// ["openai", "openai_responses"] or ["openai", "claude"].
RequestConversionChain []types.RelayFormat
// The final request format sent upstream. TODO: currently only set for Claude
FinalRequestRelayFormat types.RelayFormat
ThinkingContentInfo
TokenCountMeta
@@ -319,12 +324,15 @@ func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
}
if c.Query("beta") == "true" {
info.IsClaudeBetaQuery = true
}
info.IsClaudeBetaQuery = c.Query("beta") == "true" || isClaudeBetaForced(c)
return info
}
func isClaudeBetaForced(c *gin.Context) bool {
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
return ok && channelOtherSettings.ClaudeBetaQuery
}
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeRerank

View File

@@ -334,7 +334,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
var audioInputQuota decimal.Decimal
var audioInputPrice float64
isClaudeUsageSemantic := relayInfo.ChannelType == constant.ChannelTypeAnthropic
isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// subtract cached tokens

View File

@@ -207,6 +207,44 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
}
var claudeResponses []*dto.ClaudeResponse
// stopOpenBlocks emits the required content_block_stop event(s) for the currently open block(s)
// according to Anthropic's SSE streaming state machine:
// content_block_start -> content_block_delta* -> content_block_stop (per index).
//
// For text/thinking, there is at most one open block at info.ClaudeConvertInfo.Index.
// For tools, OpenAI tool_calls can stream multiple parallel tool_use blocks (indexed from 0),
// so we may have multiple open blocks and must stop each one explicitly.
stopOpenBlocks := func() {
switch info.ClaudeConvertInfo.LastMessagesType {
case relaycommon.LastMessageTypeText, relaycommon.LastMessageTypeThinking:
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
case relaycommon.LastMessageTypeTools:
base := info.ClaudeConvertInfo.ToolCallBaseIndex
for offset := 0; offset <= info.ClaudeConvertInfo.ToolCallMaxIndexOffset; offset++ {
claudeResponses = append(claudeResponses, generateStopBlock(base+offset))
}
}
}
// stopOpenBlocksAndAdvance closes the currently open block(s) and advances the content block index
// to the next available slot for subsequent content_block_start events.
//
// This prevents invalid streams where a content_block_delta (e.g. thinking_delta) is emitted for an
// index whose active content_block type is different (the typical cause of "Mismatched content block type").
stopOpenBlocksAndAdvance := func() {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeNone {
return
}
stopOpenBlocks()
switch info.ClaudeConvertInfo.LastMessagesType {
case relaycommon.LastMessageTypeTools:
info.ClaudeConvertInfo.Index = info.ClaudeConvertInfo.ToolCallBaseIndex + info.ClaudeConvertInfo.ToolCallMaxIndexOffset + 1
info.ClaudeConvertInfo.ToolCallBaseIndex = 0
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
default:
info.ClaudeConvertInfo.Index++
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeNone
}
if info.SendResponseCount == 1 {
msg := &dto.ClaudeMediaMessage{
Id: openAIResponse.Id,
@@ -228,6 +266,8 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
//})
if openAIResponse.IsToolCall() {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
info.ClaudeConvertInfo.ToolCallBaseIndex = 0
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
var toolCall dto.ToolCallResponse
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.ToolCalls) > 0 {
toolCall = openAIResponse.Choices[0].Delta.ToolCalls[0]
@@ -252,8 +292,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
claudeResponses = append(claudeResponses, resp)
// If the first chunk already contains a tool delta, append an input_json_delta
if toolCall.Function.Arguments != "" {
idx := 0
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "input_json_delta",
@@ -270,16 +311,21 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
content := openAIResponse.Choices[0].Delta.GetContentString()
if reasoning != "" {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
stopOpenBlocksAndAdvance()
}
idx := info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "thinking",
Thinking: common.GetPointer[string](""),
},
})
idx2 := idx
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx2,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "thinking_delta",
@@ -288,16 +334,21 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
})
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
} else if content != "" {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
stopOpenBlocksAndAdvance()
}
idx := info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
})
idx2 := idx
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx2,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "text_delta",
@@ -311,7 +362,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
// If the first chunk already carries a finish_reason, emit the stop block(s) immediately
if len(openAIResponse.Choices) > 0 && openAIResponse.Choices[0].FinishReason != nil && *openAIResponse.Choices[0].FinishReason != "" {
info.FinishReason = *openAIResponse.Choices[0].FinishReason
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
stopOpenBlocks()
oaiUsage := openAIResponse.Usage
if oaiUsage == nil {
oaiUsage = info.ClaudeConvertInfo.Usage
@@ -342,7 +393,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
// no choices
// Possibly a non-standard OpenAI response; check whether it has already finished
if info.ClaudeConvertInfo.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
stopOpenBlocks()
oaiUsage := info.ClaudeConvertInfo.Usage
if oaiUsage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
@@ -376,18 +427,25 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
if len(chosenChoice.Delta.ToolCalls) > 0 {
toolCalls := chosenChoice.Delta.ToolCalls
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
stopOpenBlocksAndAdvance()
info.ClaudeConvertInfo.ToolCallBaseIndex = info.ClaudeConvertInfo.Index
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
base := info.ClaudeConvertInfo.ToolCallBaseIndex
maxOffset := info.ClaudeConvertInfo.ToolCallMaxIndexOffset
for i, toolCall := range toolCalls {
blockIndex := info.ClaudeConvertInfo.Index
offset := 0
if toolCall.Index != nil {
blockIndex = *toolCall.Index
} else if len(toolCalls) > 1 {
blockIndex = info.ClaudeConvertInfo.Index + i
offset = *toolCall.Index
} else {
offset = i
}
if offset > maxOffset {
maxOffset = offset
}
blockIndex := base + offset
idx := blockIndex
if toolCall.Function.Name != "" {
@@ -413,17 +471,19 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
},
})
}
info.ClaudeConvertInfo.Index = blockIndex
}
info.ClaudeConvertInfo.ToolCallMaxIndexOffset = maxOffset
info.ClaudeConvertInfo.Index = base + maxOffset
} else {
reasoning := chosenChoice.Delta.GetReasoningContent()
textContent := chosenChoice.Delta.GetContentString()
if reasoning != "" || textContent != "" {
if reasoning != "" {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
stopOpenBlocksAndAdvance()
idx := info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "thinking",
@@ -438,12 +498,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
}
} else {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeThinking || info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
}
stopOpenBlocksAndAdvance()
idx := info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
@@ -462,13 +520,13 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
}
}
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponse.Index = common.GetPointer[int](info.ClaudeConvertInfo.Index)
if !isEmpty && claudeResponse.Delta != nil {
claudeResponses = append(claudeResponses, &claudeResponse)
}
if doneChunk || info.ClaudeConvertInfo.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
stopOpenBlocks()
oaiUsage := openAIResponse.Usage
if oaiUsage == nil {
oaiUsage = info.ClaudeConvertInfo.Usage
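The tool-call hunk above anchors parallel tool calls at a base content-block index and tracks per-call offsets so later chunks keep addressing the same Claude blocks. A standalone sketch of that bookkeeping (type and function names here are illustrative, not the project's):

```go
package main

import "fmt"

// toolCallDelta mirrors the relevant part of an OpenAI streaming tool-call
// delta: Index is optional, so nil means "use positional order".
type toolCallDelta struct {
	Index *int
}

// blockIndices assigns a Claude content-block index to each tool call,
// anchored at base (the index reached after the previous text/thinking
// block was closed). It also returns the running max offset, which the
// converter stores so the next chunk can advance past all open blocks.
func blockIndices(base int, calls []toolCallDelta) ([]int, int) {
	maxOffset := 0
	out := make([]int, len(calls))
	for i, c := range calls {
		offset := i // positional fallback when the provider omits Index
		if c.Index != nil {
			offset = *c.Index
		}
		if offset > maxOffset {
			maxOffset = offset
		}
		out[i] = base + offset
	}
	return out, maxOffset
}

func main() {
	two := 2
	idxs, maxOff := blockIndices(3, []toolCallDelta{{}, {Index: &two}})
	fmt.Println(idxs, maxOff) // prints: [3 5] 2
}
```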


@@ -1,10 +1,7 @@
package ratio_setting
import (
"encoding/json"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
)
var defaultCacheRatio = map[string]float64{
@@ -98,44 +95,37 @@ var defaultCreateCacheRatio = map[string]float64{
//var defaultCreateCacheRatio = map[string]float64{}
var cacheRatioMap map[string]float64
var cacheRatioMapMutex sync.RWMutex
var cacheRatioMap = types.NewRWMap[string, float64]()
var createCacheRatioMap = types.NewRWMap[string, float64]()
// GetCacheRatioMap returns the cache ratio map
// GetCacheRatioMap returns a copy of the cache ratio map
func GetCacheRatioMap() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
return cacheRatioMap
return cacheRatioMap.ReadAll()
}
// CacheRatio2JSONString converts the cache ratio map to a JSON string
func CacheRatio2JSONString() string {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil {
common.SysLog("error marshalling cache ratio: " + err.Error())
}
return string(jsonBytes)
return cacheRatioMap.MarshalJSONString()
}
// CreateCacheRatio2JSONString converts the create cache ratio map to a JSON string
func CreateCacheRatio2JSONString() string {
return createCacheRatioMap.MarshalJSONString()
}
// UpdateCacheRatioByJSONString updates the cache ratio map from a JSON string
func UpdateCacheRatioByJSONString(jsonStr string) error {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
cacheRatioMap = make(map[string]float64)
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(cacheRatioMap, jsonStr, InvalidateExposedDataCache)
}
// UpdateCreateCacheRatioByJSONString updates the create cache ratio map from a JSON string
func UpdateCreateCacheRatioByJSONString(jsonStr string) error {
return types.LoadFromJsonStringWithCallback(createCacheRatioMap, jsonStr, InvalidateExposedDataCache)
}
// GetCacheRatio returns the cache ratio for a model
func GetCacheRatio(name string) (float64, bool) {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
ratio, ok := cacheRatioMap[name]
ratio, ok := cacheRatioMap.Get(name)
if !ok {
return 1, false // Default to 1 if not found
}
@@ -143,7 +133,7 @@ func GetCacheRatio(name string) (float64, bool) {
}
func GetCreateCacheRatio(name string) (float64, bool) {
ratio, ok := defaultCreateCacheRatio[name]
ratio, ok := createCacheRatioMap.Get(name)
if !ok {
return 1.25, false // Default to 1.25 if not found
}
@@ -151,11 +141,9 @@ func GetCreateCacheRatio(name string) (float64, bool) {
}
func GetCacheRatioCopy() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(cacheRatioMap))
for k, v := range cacheRatioMap {
copyMap[k] = v
}
return copyMap
return cacheRatioMap.ReadAll()
}
func GetCreateCacheRatioCopy() map[string]float64 {
return createCacheRatioMap.ReadAll()
}


@@ -42,10 +42,11 @@ func GetExposedData() gin.H {
return cloneGinH(c.data)
}
newData := gin.H{
"model_ratio": GetModelRatioCopy(),
"completion_ratio": GetCompletionRatioCopy(),
"cache_ratio": GetCacheRatioCopy(),
"model_price": GetModelPriceCopy(),
"model_ratio": GetModelRatioCopy(),
"completion_ratio": GetCompletionRatioCopy(),
"cache_ratio": GetCacheRatioCopy(),
"create_cache_ratio": GetCreateCacheRatioCopy(),
"model_price": GetModelPriceCopy(),
}
exposedData.Store(&exposedCache{
data: newData,


@@ -3,29 +3,27 @@ package ratio_setting
import (
"encoding/json"
"errors"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/config"
"github.com/QuantumNous/new-api/types"
)
var groupRatio = map[string]float64{
var defaultGroupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
var groupRatioMutex sync.RWMutex
var groupRatioMap = types.NewRWMap[string, float64]()
var (
GroupGroupRatio = map[string]map[string]float64{
"vip": {
"edit_this": 0.9,
},
}
groupGroupRatioMutex sync.RWMutex
)
var defaultGroupGroupRatio = map[string]map[string]float64{
"vip": {
"edit_this": 0.9,
},
}
var groupGroupRatioMap = types.NewRWMap[string, map[string]float64]()
var defaultGroupSpecialUsableGroup = map[string]map[string]string{
"vip": {
@@ -35,9 +33,9 @@ var defaultGroupSpecialUsableGroup = map[string]map[string]string{
}
type GroupRatioSetting struct {
GroupRatio map[string]float64 `json:"group_ratio"`
GroupGroupRatio map[string]map[string]float64 `json:"group_group_ratio"`
GroupSpecialUsableGroup *types.RWMap[string, map[string]string] `json:"group_special_usable_group"`
GroupRatio *types.RWMap[string, float64] `json:"group_ratio"`
GroupGroupRatio *types.RWMap[string, map[string]float64] `json:"group_group_ratio"`
GroupSpecialUsableGroup *types.RWMap[string, map[string]string] `json:"group_special_usable_group"`
}
var groupRatioSetting GroupRatioSetting
@@ -46,10 +44,13 @@ func init() {
groupSpecialUsableGroup := types.NewRWMap[string, map[string]string]()
groupSpecialUsableGroup.AddAll(defaultGroupSpecialUsableGroup)
groupRatioMap.AddAll(defaultGroupRatio)
groupGroupRatioMap.AddAll(defaultGroupGroupRatio)
groupRatioSetting = GroupRatioSetting{
GroupSpecialUsableGroup: groupSpecialUsableGroup,
GroupRatio: groupRatio,
GroupGroupRatio: GroupGroupRatio,
GroupRatio: groupRatioMap,
GroupGroupRatio: groupGroupRatioMap,
}
config.GlobalConfig.Register("group_ratio_setting", &groupRatioSetting)
@@ -64,48 +65,24 @@ func GetGroupRatioSetting() *GroupRatioSetting {
}
func GetGroupRatioCopy() map[string]float64 {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
}
return groupRatioCopy
return groupRatioMap.ReadAll()
}
func ContainsGroupRatio(name string) bool {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
_, ok := groupRatio[name]
_, ok := groupRatioMap.Get(name)
return ok
}
func GroupRatio2JSONString() string {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
common.SysLog("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
return groupRatioMap.MarshalJSONString()
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
groupRatioMutex.Lock()
defer groupRatioMutex.Unlock()
groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &groupRatio)
return types.LoadFromJsonString(groupRatioMap, jsonStr)
}
func GetGroupRatio(name string) float64 {
groupRatioMutex.RLock()
defer groupRatioMutex.RUnlock()
ratio, ok := groupRatio[name]
ratio, ok := groupRatioMap.Get(name)
if !ok {
common.SysLog("group ratio not found: " + name)
return 1
@@ -114,10 +91,7 @@ func GetGroupRatio(name string) float64 {
}
func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
groupGroupRatioMutex.RLock()
defer groupGroupRatioMutex.RUnlock()
gp, ok := GroupGroupRatio[userGroup]
gp, ok := groupGroupRatioMap.Get(userGroup)
if !ok {
return -1, false
}
@@ -129,22 +103,11 @@ func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
}
func GroupGroupRatio2JSONString() string {
groupGroupRatioMutex.RLock()
defer groupGroupRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(GroupGroupRatio)
if err != nil {
common.SysLog("error marshalling group-group ratio: " + err.Error())
}
return string(jsonBytes)
return groupGroupRatioMap.MarshalJSONString()
}
func UpdateGroupGroupRatioByJSONString(jsonStr string) error {
groupGroupRatioMutex.Lock()
defer groupGroupRatioMutex.Unlock()
GroupGroupRatio = make(map[string]map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupGroupRatio)
return types.LoadFromJsonString(groupGroupRatioMap, jsonStr)
}
func CheckGroupRatio(jsonStr string) error {


@@ -1,12 +1,11 @@
package ratio_setting
import (
"encoding/json"
"strings"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/types"
)
// from songquanpeng/one-api
@@ -319,19 +318,9 @@ var defaultAudioCompletionRatio = map[string]float64{
"tts-1-hd-1106": 0,
}
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var modelPriceMap = types.NewRWMap[string, float64]()
var modelRatioMap = types.NewRWMap[string, float64]()
var completionRatioMap = types.NewRWMap[string, float64]()
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
@@ -342,79 +331,34 @@ var defaultCompletionRatio = map[string]float64{
// InitRatioSettings initializes all model related settings maps
func InitRatioSettings() {
// Initialize modelPriceMap
modelPriceMapMutex.Lock()
modelPriceMap = defaultModelPrice
modelPriceMapMutex.Unlock()
// Initialize modelRatioMap
modelRatioMapMutex.Lock()
modelRatioMap = defaultModelRatio
modelRatioMapMutex.Unlock()
// Initialize CompletionRatio
CompletionRatioMutex.Lock()
CompletionRatio = defaultCompletionRatio
CompletionRatioMutex.Unlock()
// Initialize cacheRatioMap
cacheRatioMapMutex.Lock()
cacheRatioMap = defaultCacheRatio
cacheRatioMapMutex.Unlock()
// initialize imageRatioMap
imageRatioMapMutex.Lock()
imageRatioMap = defaultImageRatio
imageRatioMapMutex.Unlock()
// initialize audioRatioMap
audioRatioMapMutex.Lock()
audioRatioMap = defaultAudioRatio
audioRatioMapMutex.Unlock()
// initialize audioCompletionRatioMap
audioCompletionRatioMapMutex.Lock()
audioCompletionRatioMap = defaultAudioCompletionRatio
audioCompletionRatioMapMutex.Unlock()
modelPriceMap.AddAll(defaultModelPrice)
modelRatioMap.AddAll(defaultModelRatio)
completionRatioMap.AddAll(defaultCompletionRatio)
cacheRatioMap.AddAll(defaultCacheRatio)
createCacheRatioMap.AddAll(defaultCreateCacheRatio)
imageRatioMap.AddAll(defaultImageRatio)
audioRatioMap.AddAll(defaultAudioRatio)
audioCompletionRatioMap.AddAll(defaultAudioCompletionRatio)
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
return modelPriceMap
return modelPriceMap.ReadAll()
}
func ModelPrice2JSONString() string {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
jsonBytes, err := common.Marshal(modelPriceMap)
if err != nil {
common.SysError("error marshalling model price: " + err.Error())
}
return string(jsonBytes)
return modelPriceMap.MarshalJSONString()
}
func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(modelPriceMap, jsonStr, InvalidateExposedDataCache)
}
// GetModelPrice returns the model's price; if the model does not exist it returns -1, false
func GetModelPrice(name string, printErr bool) (float64, bool) {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if strings.HasSuffix(name, CompactModelSuffix) {
price, ok := modelPriceMap[CompactWildcardModelKey]
price, ok := modelPriceMap.Get(CompactWildcardModelKey)
if !ok {
if printErr {
common.SysError("model price not found: " + name)
@@ -424,7 +368,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
price, ok := modelPriceMap[name]
price, ok := modelPriceMap.Get(name)
if !ok {
if printErr {
common.SysError("model price not found: " + name)
@@ -435,14 +379,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
err := common.Unmarshal([]byte(jsonStr), &modelRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(modelRatioMap, jsonStr, InvalidateExposedDataCache)
}
// Normalize model names that carry a thinking budget, so pricing can be unified
@@ -454,15 +391,12 @@ func handleThinkingBudgetModel(name, prefix, wildcard string) string {
}
func GetModelRatio(name string) (float64, bool, string) {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
ratio, ok := modelRatioMap[name]
ratio, ok := modelRatioMap.Get(name)
if !ok {
if strings.HasSuffix(name, CompactModelSuffix) {
if wildcardRatio, ok := modelRatioMap[CompactWildcardModelKey]; ok {
if wildcardRatio, ok := modelRatioMap.Get(CompactWildcardModelKey); ok {
return wildcardRatio, true, name
}
//return 0, true, name
@@ -488,54 +422,19 @@ func GetDefaultModelPriceMap() map[string]float64 {
return defaultModelPrice
}
func GetDefaultImageRatioMap() map[string]float64 {
return defaultImageRatio
}
func GetDefaultAudioRatioMap() map[string]float64 {
return defaultAudioRatio
}
func GetDefaultAudioCompletionRatioMap() map[string]float64 {
return defaultAudioCompletionRatio
}
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
return CompletionRatio
}
func CompletionRatio2JSONString() string {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
common.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
return completionRatioMap.MarshalJSONString()
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
err := common.Unmarshal([]byte(jsonStr), &CompletionRatio)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(completionRatioMap, jsonStr, InvalidateExposedDataCache)
}
func GetCompletionRatio(name string) float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
name = FormatMatchingModelName(name)
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
if ratio, ok := completionRatioMap.Get(name); ok {
return ratio
}
}
@@ -543,7 +442,7 @@ func GetCompletionRatio(name string) float64 {
if contain {
return hardCodedRatio
}
if ratio, ok := CompletionRatio[name]; ok {
if ratio, ok := completionRatioMap.Get(name); ok {
return ratio
}
return hardCodedRatio
@@ -671,88 +570,54 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
}
func GetAudioRatio(name string) float64 {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if ratio, ok := audioRatioMap[name]; ok {
if ratio, ok := audioRatioMap.Get(name); ok {
return ratio
}
return 1
}
func GetAudioCompletionRatio(name string) float64 {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if ratio, ok := audioCompletionRatioMap[name]; ok {
if ratio, ok := audioCompletionRatioMap.Get(name); ok {
return ratio
}
return 1
}
func ContainsAudioRatio(name string) bool {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
_, ok := audioRatioMap[name]
_, ok := audioRatioMap.Get(name)
return ok
}
func ContainsAudioCompletionRatio(name string) bool {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
_, ok := audioCompletionRatioMap[name]
_, ok := audioCompletionRatioMap.Get(name)
return ok
}
func ModelRatio2JSONString() string {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
return modelRatioMap.MarshalJSONString()
}
var defaultImageRatio = map[string]float64{
"gpt-image-1": 2,
}
var imageRatioMap map[string]float64
var imageRatioMapMutex sync.RWMutex
var (
audioRatioMap map[string]float64 = nil
audioRatioMapMutex = sync.RWMutex{}
)
var (
audioCompletionRatioMap map[string]float64 = nil
audioCompletionRatioMapMutex = sync.RWMutex{}
)
var imageRatioMap = types.NewRWMap[string, float64]()
var audioRatioMap = types.NewRWMap[string, float64]()
var audioCompletionRatioMap = types.NewRWMap[string, float64]()
func ImageRatio2JSONString() string {
imageRatioMapMutex.RLock()
defer imageRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(imageRatioMap)
if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error())
}
return string(jsonBytes)
return imageRatioMap.MarshalJSONString()
}
func UpdateImageRatioByJSONString(jsonStr string) error {
imageRatioMapMutex.Lock()
defer imageRatioMapMutex.Unlock()
imageRatioMap = make(map[string]float64)
return common.Unmarshal([]byte(jsonStr), &imageRatioMap)
return types.LoadFromJsonString(imageRatioMap, jsonStr)
}
func GetImageRatio(name string) (float64, bool) {
imageRatioMapMutex.RLock()
defer imageRatioMapMutex.RUnlock()
ratio, ok := imageRatioMap[name]
ratio, ok := imageRatioMap.Get(name)
if !ok {
return 1, false // Default to 1 if not found
}
@@ -760,78 +625,31 @@ func GetImageRatio(name string) (float64, bool) {
}
func AudioRatio2JSONString() string {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(audioRatioMap)
if err != nil {
common.SysError("error marshalling audio ratio: " + err.Error())
}
return string(jsonBytes)
return audioRatioMap.MarshalJSONString()
}
func UpdateAudioRatioByJSONString(jsonStr string) error {
tmp := make(map[string]float64)
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
return err
}
audioRatioMapMutex.Lock()
audioRatioMap = tmp
audioRatioMapMutex.Unlock()
InvalidateExposedDataCache()
return nil
return types.LoadFromJsonStringWithCallback(audioRatioMap, jsonStr, InvalidateExposedDataCache)
}
func AudioCompletionRatio2JSONString() string {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(audioCompletionRatioMap)
if err != nil {
common.SysError("error marshalling audio completion ratio: " + err.Error())
}
return string(jsonBytes)
return audioCompletionRatioMap.MarshalJSONString()
}
func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
tmp := make(map[string]float64)
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
return err
}
audioCompletionRatioMapMutex.Lock()
audioCompletionRatioMap = tmp
audioCompletionRatioMapMutex.Unlock()
InvalidateExposedDataCache()
return nil
return types.LoadFromJsonStringWithCallback(audioCompletionRatioMap, jsonStr, InvalidateExposedDataCache)
}
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelRatioMap))
for k, v := range modelRatioMap {
copyMap[k] = v
}
return copyMap
return modelRatioMap.ReadAll()
}
func GetModelPriceCopy() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelPriceMap))
for k, v := range modelPriceMap {
copyMap[k] = v
}
return copyMap
return modelPriceMap.ReadAll()
}
func GetCompletionRatioCopy() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
copyMap := make(map[string]float64, len(CompletionRatio))
for k, v := range CompletionRatio {
copyMap[k] = v
}
return copyMap
return completionRatioMap.ReadAll()
}
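The lookup paths in GetModelPrice/GetModelRatio above first try the exact model name, then fall back to a wildcard entry for names carrying the compact suffix. A sketch of that two-step resolution, with placeholder values for the suffix and wildcard key (the real CompactModelSuffix/CompactWildcardModelKey constants live elsewhere in the package):

```go
package main

import (
	"fmt"
	"strings"
)

// Placeholder values; the project defines CompactModelSuffix and
// CompactWildcardModelKey elsewhere.
const (
	compactSuffix      = "-compact"
	compactWildcardKey = "*-compact"
)

// lookupRatio mirrors the resolution order: exact match first, then the
// wildcard entry for any name ending in the compact suffix, else a default.
func lookupRatio(ratios map[string]float64, name string) (float64, bool) {
	if r, ok := ratios[name]; ok {
		return r, true
	}
	if strings.HasSuffix(name, compactSuffix) {
		if r, ok := ratios[compactWildcardKey]; ok {
			return r, true
		}
	}
	return 1, false // default when nothing matches
}

func main() {
	ratios := map[string]float64{"gpt-4": 15, compactWildcardKey: 2}
	fmt.Println(lookupRatio(ratios, "gpt-4"))          // prints: 15 true
	fmt.Println(lookupRatio(ratios, "gpt-5-compact"))  // prints: 2 true
}
```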
// Convert model names so channels need not configure every parameterized variant


@@ -80,3 +80,24 @@ func LoadFromJsonString[K comparable, V any](m *RWMap[K, V], jsonStr string) err
m.data = make(map[K]V)
return common.Unmarshal([]byte(jsonStr), &m.data)
}
// LoadFromJsonStringWithCallback loads a JSON string into the RWMap and calls the callback on success.
func LoadFromJsonStringWithCallback[K comparable, V any](m *RWMap[K, V], jsonStr string, onSuccess func()) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.data = make(map[K]V)
err := common.Unmarshal([]byte(jsonStr), &m.data)
if err == nil && onSuccess != nil {
onSuccess()
}
return err
}
// MarshalJSONString returns the JSON string representation of the RWMap.
func (m *RWMap[K, V]) MarshalJSONString() string {
bytes, err := m.MarshalJSON()
if err != nil {
return "{}"
}
return string(bytes)
}
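The helpers added here round out a small generics-based RWMap that the ratio files migrate to. A self-contained approximation of the pattern (using encoding/json directly, where the project routes marshalling through its common package):

```go
package main

import (
	"encoding/json"
	"fmt"
	"sync"
)

// RWMap is a minimal generic map guarded by an RWMutex, approximating the
// types.RWMap the diffs migrate to.
type RWMap[K comparable, V any] struct {
	mu   sync.RWMutex
	data map[K]V
}

func NewRWMap[K comparable, V any]() *RWMap[K, V] {
	return &RWMap[K, V]{data: make(map[K]V)}
}

func (m *RWMap[K, V]) Get(k K) (V, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	v, ok := m.data[k]
	return v, ok
}

// ReadAll returns a copy, so callers can iterate without holding the lock.
func (m *RWMap[K, V]) ReadAll() map[K]V {
	m.mu.RLock()
	defer m.mu.RUnlock()
	out := make(map[K]V, len(m.data))
	for k, v := range m.data {
		out[k] = v
	}
	return out
}

// MarshalJSONString serializes the map, falling back to "{}" on error.
func (m *RWMap[K, V]) MarshalJSONString() string {
	m.mu.RLock()
	defer m.mu.RUnlock()
	b, err := json.Marshal(m.data)
	if err != nil {
		return "{}"
	}
	return string(b)
}

// LoadFromJSONStringWithCallback replaces the contents from JSON and invokes
// onSuccess (e.g. a cache invalidation hook) only when parsing succeeds.
func LoadFromJSONStringWithCallback[K comparable, V any](m *RWMap[K, V], s string, onSuccess func()) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.data = make(map[K]V)
	if err := json.Unmarshal([]byte(s), &m.data); err != nil {
		return err
	}
	if onSuccess != nil {
		onSuccess()
	}
	return nil
}

func main() {
	ratios := NewRWMap[string, float64]()
	invalidated := false
	_ = LoadFromJSONStringWithCallback(ratios, `{"gpt-image-1": 2}`, func() { invalidated = true })
	r, ok := ratios.Get("gpt-image-1")
	fmt.Println(r, ok, invalidated) // prints: 2 true true
}
```

Running the callback only on successful unmarshal is what lets call sites like UpdateCacheRatioByJSONString collapse to a one-liner while preserving the old "invalidate only on success" behavior.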


@@ -36,6 +36,7 @@ const RatioSetting = () => {
ModelPrice: '',
ModelRatio: '',
CacheRatio: '',
CreateCacheRatio: '',
CompletionRatio: '',
GroupRatio: '',
GroupGroupRatio: '',


@@ -107,9 +107,11 @@ const AccountManagement = ({
const res = await API.get('/api/user/oauth/bindings');
if (res.data.success) {
setCustomOAuthBindings(res.data.data || []);
} else {
showError(res.data.message || t('获取绑定信息失败'));
}
} catch (error) {
// ignore
showError(error.response?.data?.message || error.message || t('获取绑定信息失败'));
}
};
@@ -131,7 +133,7 @@ const AccountManagement = ({
showError(res.data.message);
}
} catch (error) {
showError(t('操作失败'));
showError(error.response?.data?.message || error.message || t('操作失败'));
} finally {
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
}


@@ -170,6 +170,7 @@ const EditChannelModal = (props) => {
allow_service_tier: false,
disable_store: false, // false = allow pass-through (enabled by default)
allow_safety_identifier: false,
claude_beta_query: false,
};
const [batch, setBatch] = useState(false);
const [multiToSingle, setMultiToSingle] = useState(false);
@@ -633,6 +634,7 @@ const EditChannelModal = (props) => {
data.disable_store = parsedSettings.disable_store || false;
data.allow_safety_identifier =
parsedSettings.allow_safety_identifier || false;
data.claude_beta_query = parsedSettings.claude_beta_query || false;
} catch (error) {
console.error('解析其他设置失败:', error);
data.azure_responses_version = '';
@@ -643,6 +645,7 @@ const EditChannelModal = (props) => {
data.allow_service_tier = false;
data.disable_store = false;
data.allow_safety_identifier = false;
data.claude_beta_query = false;
}
} else {
// Backward compatibility: legacy channels without settings default to JSON display
@@ -652,6 +655,7 @@ const EditChannelModal = (props) => {
data.allow_service_tier = false;
data.disable_store = false;
data.allow_safety_identifier = false;
data.claude_beta_query = false;
}
if (
@@ -1394,6 +1398,9 @@ const EditChannelModal = (props) => {
settings.allow_safety_identifier =
localInputs.allow_safety_identifier === true;
}
if (localInputs.type === 14) {
settings.claude_beta_query = localInputs.claude_beta_query === true;
}
}
localInputs.settings = JSON.stringify(settings);
@@ -1414,6 +1421,7 @@ const EditChannelModal = (props) => {
delete localInputs.allow_service_tier;
delete localInputs.disable_store;
delete localInputs.allow_safety_identifier;
delete localInputs.claude_beta_query;
let res;
localInputs.auto_ban = localInputs.auto_ban ? 1 : 0;
@@ -3316,6 +3324,24 @@ const EditChannelModal = (props) => {
</div>
</div>
{inputs.type === 14 && (
<Form.Switch
field='claude_beta_query'
label={t('Claude 强制 beta=true')}
checkedText={t('开')}
uncheckedText={t('关')}
onChange={(value) =>
handleChannelOtherSettingsChange(
'claude_beta_query',
value,
)
}
extraText={t(
'开启后,该渠道请求 Claude 时将强制追加 ?beta=true,无需客户端手动传参',
)}
/>
)}
{inputs.type === 1 && (
<Form.Switch
field='force_format'


@@ -26,8 +26,10 @@ import {
Tag,
Typography,
Select,
Switch,
Banner,
} from '@douyinfe/semi-ui';
import { IconSearch } from '@douyinfe/semi-icons';
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
import { copy, showError, showInfo, showSuccess } from '../../../../helpers';
import { MODEL_TABLE_PAGE_SIZE } from '../../../../constants';
@@ -48,11 +50,25 @@ const ModelTestModal = ({
setModelTablePage,
selectedEndpointType,
setSelectedEndpointType,
isStreamTest,
setIsStreamTest,
allSelectingRef,
isMobile,
t,
}) => {
const hasChannel = Boolean(currentTestChannel);
const streamToggleDisabled = [
'embeddings',
'image-generation',
'jina-rerank',
'openai-response-compact',
].includes(selectedEndpointType);
React.useEffect(() => {
if (streamToggleDisabled && isStreamTest) {
setIsStreamTest(false);
}
}, [streamToggleDisabled, isStreamTest, setIsStreamTest]);
const filteredModels = hasChannel
? currentTestChannel.models
@@ -181,6 +197,7 @@ const ModelTestModal = ({
currentTestChannel,
record.model,
selectedEndpointType,
isStreamTest,
)
}
loading={isTesting}
@@ -258,25 +275,46 @@ const ModelTestModal = ({
>
{hasChannel && (
<div className='model-test-scroll'>
{/* Endpoint type selector */}
<div className='flex items-center gap-2 w-full mb-2'>
<Typography.Text strong>{t('端点类型')}:</Typography.Text>
<Select
value={selectedEndpointType}
onChange={setSelectedEndpointType}
optionList={endpointTypeOptions}
className='!w-full'
placeholder={t('选择端点类型')}
/>
{/* Endpoint toolbar */}
<div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
<div className='flex items-center gap-2 flex-1 min-w-0'>
<Typography.Text strong className='shrink-0'>
{t('端点类型')}:
</Typography.Text>
<Select
value={selectedEndpointType}
onChange={setSelectedEndpointType}
optionList={endpointTypeOptions}
className='!w-full min-w-0'
placeholder={t('选择端点类型')}
/>
</div>
<div className='flex items-center justify-between sm:justify-end gap-2 shrink-0'>
<Typography.Text strong className='shrink-0'>
{t('流式')}:
</Typography.Text>
<Switch
checked={isStreamTest}
onChange={setIsStreamTest}
size='small'
disabled={streamToggleDisabled}
aria-label={t('流式')}
/>
</div>
</div>
<Typography.Text type='tertiary' size='small' className='block mb-2'>
{t(
<Banner
type='info'
closeIcon={null}
icon={<IconInfoCircle />}
className='!rounded-lg mb-2'
description={t(
'说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。',
)}
</Typography.Text>
/>
{/* Search and action buttons */}
<div className='flex items-center justify-end gap-2 w-full mb-2'>
<div className='flex flex-col sm:flex-row sm:items-center gap-2 w-full mb-2'>
<Input
placeholder={t('搜索模型...')}
value={modelSearchKeyword}
@@ -284,16 +322,17 @@ const ModelTestModal = ({
setModelSearchKeyword(v);
setModelTablePage(1);
}}
className='!w-full'
className='!w-full sm:!flex-1'
prefix={<IconSearch />}
showClear
/>
<Button onClick={handleCopySelected}>{t('复制已选')}</Button>
<Button type='tertiary' onClick={handleSelectSuccess}>
{t('选择成功')}
</Button>
<div className='flex items-center justify-end gap-2'>
<Button onClick={handleCopySelected}>{t('复制已选')}</Button>
<Button type='tertiary' onClick={handleSelectSuccess}>
{t('选择成功')}
</Button>
</div>
</div>
<Table


@@ -87,6 +87,7 @@ export const useChannelsData = () => {
const [isBatchTesting, setIsBatchTesting] = useState(false);
const [modelTablePage, setModelTablePage] = useState(1);
const [selectedEndpointType, setSelectedEndpointType] = useState('');
const [isStreamTest, setIsStreamTest] = useState(false);
const [globalPassThroughEnabled, setGlobalPassThroughEnabled] =
useState(false);
@@ -851,7 +852,12 @@ export const useChannelsData = () => {
};
// Test channel - single-model test, modeled on the legacy implementation
const testChannel = async (record, model, endpointType = '') => {
const testChannel = async (
record,
model,
endpointType = '',
stream = false,
) => {
const testKey = `${record.id}-${model}`;
// Check whether batch testing should stop
@@ -867,6 +873,9 @@ export const useChannelsData = () => {
if (endpointType) {
url += `&endpoint_type=${endpointType}`;
}
if (stream) {
url += `&stream=true`;
}
const res = await API.get(url);
// Check whether testing was stopped during the request
@@ -995,7 +1004,12 @@ export const useChannelsData = () => {
);
const batchPromises = batch.map((model) =>
testChannel(currentTestChannel, model, selectedEndpointType),
testChannel(
currentTestChannel,
model,
selectedEndpointType,
isStreamTest,
),
);
const batchResults = await Promise.allSettled(batchPromises);
results.push(...batchResults);
@@ -1080,6 +1094,7 @@ export const useChannelsData = () => {
setSelectedModelKeys([]);
setModelTablePage(1);
setSelectedEndpointType('');
setIsStreamTest(false);
// Optionally keep test results; not cleared here so users can review them
};
@@ -1170,6 +1185,8 @@ export const useChannelsData = () => {
setModelTablePage,
selectedEndpointType,
setSelectedEndpointType,
isStreamTest,
setIsStreamTest,
allSelectingRef,
// Multi-key management states


@@ -1146,6 +1146,8 @@
"提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址": "Tip: {key} in the link will be replaced with the API key, {address} will be replaced with the server address",
"提示价格:{{symbol}}{{price}} / 1M tokens": "Prompt price: {{symbol}}{{price}} / 1M tokens",
"提示缓存倍率": "Prompt cache ratio",
"缓存创建倍率": "Cache creation ratio",
"默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "Defaults to the 5m cache creation ratio; the 1h cache creation ratio is computed by fixed multiplication (currently 1.6x)",
"搜索供应商": "Search vendor",
"搜索关键字": "Search keywords",
"搜索失败": "Search failed",
@@ -1548,6 +1550,7 @@
"流": "stream",
"流式响应完成": "Streaming response completed",
"流式输出": "Streaming Output",
"流式": "Streaming",
"流量端口": "Traffic Port",
"浅色": "Light",
"浅色模式": "Light Mode",


@@ -1156,6 +1156,8 @@
"提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址": "Astuce : {key} dans le lien sera remplacé par la clé API, {address} sera remplacé par l'adresse du serveur",
"提示价格:{{symbol}}{{price}} / 1M tokens": "Prix d'invite : {{symbol}}{{price}} / 1M tokens",
"提示缓存倍率": "Ratio de cache d'invite",
"缓存创建倍率": "Ratio de création du cache",
"默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x)": "Par défaut, le ratio de création de cache 5m est utilisé ; le ratio de création de cache 1h est calculé via une multiplication fixe (actuellement 1.6x)",
"搜索供应商": "Rechercher un fournisseur",
"搜索关键字": "Rechercher des mots-clés",
"搜索失败": "Search failed",
@@ -1558,6 +1560,7 @@
"流": "Flux",
"流式响应完成": "Flux terminé",
"流式输出": "Sortie en flux",
"流式": "Streaming",
"流量端口": "Traffic Port",
"浅色": "Clair",
"浅色模式": "Mode clair",


@@ -1141,6 +1141,8 @@
"提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址": "ヒント:リンク内の{key}はAPIキーに、{address}はサーバーURLに置換されます",
"提示价格:{{symbol}}{{price}} / 1M tokens": "プロンプト料金:{{symbol}}{{price}} / 1M tokens",
"提示缓存倍率": "プロンプトキャッシュ倍率",
"缓存创建倍率": "キャッシュ作成倍率",
"默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x": "デフォルトは5mのキャッシュ作成倍率です。1hのキャッシュ作成倍率は固定乗数で自動計算されます現在は1.6倍)",
"搜索供应商": "プロバイダーで検索",
"搜索关键字": "検索キーワード",
"搜索失败": "Search failed",
@@ -1543,6 +1545,7 @@
"流": "ストリーム",
"流式响应完成": "ストリーム完了",
"流式输出": "ストリーム出力",
"流式": "ストリーミング",
"流量端口": "Traffic Port",
"浅色": "ライト",
"浅色模式": "ライトモード",


@@ -1167,6 +1167,8 @@
"提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址": "Промпт: {key} в ссылке будет заменен на API-ключ, {address} будет заменен на адрес сервера",
"提示价格:{{symbol}}{{price}} / 1M tokens": "Цена промпта: {{symbol}}{{price}} / 1M токенов",
"提示缓存倍率": "Коэффициент кэша промптов",
"缓存创建倍率": "Коэффициент создания кэша",
"默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x": "По умолчанию используется коэффициент создания кэша 5m; коэффициент создания кэша 1h автоматически вычисляется фиксированным умножением (сейчас 1.6x)",
"搜索供应商": "Поиск поставщиков",
"搜索关键字": "Поиск по ключевым словам",
"搜索失败": "Search failed",
@@ -1569,6 +1571,7 @@
"流": "Поток",
"流式响应完成": "Поток завершён",
"流式输出": "Потоковый вывод",
"流式": "Стриминг",
"流量端口": "Traffic Port",
"浅色": "Светлая",
"浅色模式": "Светлый режим",


@@ -1142,6 +1142,8 @@
"提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址": "Mẹo: {key} trong liên kết sẽ được thay thế bằng khóa API, {address} sẽ được thay thế bằng địa chỉ máy chủ",
"提示价格:{{symbol}}{{price}} / 1M tokens": "Giá gợi ý: {{symbol}}{{price}} / 1M tokens",
"提示缓存倍率": "Tỷ lệ bộ nhớ đệm gợi ý",
"缓存创建倍率": "Tỷ lệ tạo bộ nhớ đệm",
"默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x": "Mặc định dùng tỷ lệ tạo bộ nhớ đệm 5m; tỷ lệ tạo bộ nhớ đệm 1h được tự động tính bằng phép nhân cố định (hiện là 1.6x)",
"搜索供应商": "Tìm kiếm nhà cung cấp",
"搜索关键字": "Từ khóa tìm kiếm",
"搜索失败": "Search failed",
@@ -1597,6 +1599,7 @@
"流": "luồng",
"流式响应完成": "Luồng hoàn tất",
"流式输出": "Đầu ra luồng",
"流式": "Streaming",
"流量端口": "Traffic Port",
"浅色": "Sáng",
"浅色模式": "Chế độ sáng",


@@ -1136,6 +1136,8 @@
"提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址": "提示:链接中的{key}将被替换为API密钥{address}将被替换为服务器地址",
"提示价格:{{symbol}}{{price}} / 1M tokens": "提示价格:{{symbol}}{{price}} / 1M tokens",
"提示缓存倍率": "提示缓存倍率",
"缓存创建倍率": "缓存创建倍率",
"默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x": "默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x",
"搜索供应商": "搜索供应商",
"搜索关键字": "搜索关键字",
"搜索失败": "搜索失败",
@@ -1538,6 +1540,7 @@
"流": "流",
"流式响应完成": "流式响应完成",
"流式输出": "流式输出",
"流式": "流式",
"流量端口": "流量端口",
"浅色": "浅色",
"浅色模式": "浅色模式",


@@ -43,6 +43,7 @@ export default function ModelRatioSettings(props) {
ModelPrice: '',
ModelRatio: '',
CacheRatio: '',
CreateCacheRatio: '',
CompletionRatio: '',
ImageRatio: '',
AudioRatio: '',
@@ -200,6 +201,30 @@ export default function ModelRatioSettings(props) {
/>
</Col>
</Row>
<Row gutter={16}>
<Col xs={24} sm={16}>
<Form.TextArea
label={t('缓存创建倍率')}
extraText={t(
'默认为 5m 缓存创建倍率1h 缓存创建倍率按固定乘法自动计算(当前为 1.6x',
)}
placeholder={t('为一个 JSON 文本,键为模型名称,值为倍率')}
field={'CreateCacheRatio'}
autosize={{ minRows: 6, maxRows: 12 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: '不是合法的 JSON 字符串',
},
]}
onChange={(value) =>
setInputs({ ...inputs, CreateCacheRatio: value })
}
/>
</Col>
</Row>
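
Per the field's extra text and placeholder, `CreateCacheRatio` is a JSON object mapping model names to 5m cache-creation ratios, and the 1h ratio is derived by a fixed multiplier (currently 1.6x). A minimal sketch of that derivation under those assumptions; the 1.6 constant, the example model name, and the function name are illustrative, taken from the help text rather than the backend code:

```javascript
// CreateCacheRatio is a JSON text: keys are model names, values are the
// 5m cache-creation ratios, e.g. {"claude-sonnet-example": 1.25}.
const ONE_HOUR_CACHE_MULTIPLIER = 1.6; // "currently 1.6x" per the help text

function getCacheCreationRatios(createCacheRatioJSON, model) {
  const ratios = JSON.parse(createCacheRatioJSON);
  const fiveMinute = ratios[model];
  if (fiveMinute === undefined) return null; // model not configured
  // Only the 5m ratio is stored; the 1h ratio is computed by fixed multiplication.
  return { fiveMinute, oneHour: fiveMinute * ONE_HOUR_CACHE_MULTIPLIER };
}
```

Storing only the 5m ratio keeps the form's JSON small; the `verifyJSON` validator in the field above only needs to check that the text parses, since the 1h value never appears in the config.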
<Row gutter={16}>
<Col xs={24} sm={16}>
<Form.TextArea