From 4e8615f276114b41e638294f11e3a1a00189aa4a Mon Sep 17 00:00:00 2001 From: Wang Lvyuan <74089601+LvyuanW@users.noreply.github.com> Date: Sat, 14 Mar 2026 10:47:31 +0800 Subject: [PATCH 1/2] fix: honor account model mapping before group fallback --- backend/internal/service/account.go | 22 ++++-- .../internal/service/account_wildcard_test.go | 63 +++++++++++++++++ .../openai_gateway_chat_completions.go | 5 +- .../service/openai_gateway_messages.go | 6 +- .../internal/service/openai_model_mapping.go | 19 +++++ .../service/openai_model_mapping_test.go | 70 +++++++++++++++++++ 6 files changed, 171 insertions(+), 14 deletions(-) create mode 100644 backend/internal/service/openai_model_mapping.go create mode 100644 backend/internal/service/openai_model_mapping_test.go diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 9d4f73d4..91c85196 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -521,16 +521,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool { // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) // 如果未配置 mapping,返回原始模型名 func (a *Account) GetMappedModel(requestedModel string) string { + mappedModel, _ := a.ResolveMappedModel(requestedModel) + return mappedModel +} + +// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。 +// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。 +func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) { mapping := a.GetModelMapping() if len(mapping) == 0 { - return requestedModel + return requestedModel, false } // 精确匹配优先 if mappedModel, exists := mapping[requestedModel]; exists { - return mappedModel + return mappedModel, true } // 通配符匹配(最长优先) - return matchWildcardMapping(mapping, requestedModel) + return matchWildcardMappingResult(mapping, requestedModel) } func (a *Account) GetBaseURL() string { @@ -607,6 +614,11 @@ func matchWildcard(pattern, str string) bool { // matchWildcardMapping 通配符映射匹配(最长优先) // 如果没有匹配,返回原始字符串 func matchWildcardMapping(mapping map[string]string, requestedModel string) string { + mappedModel, _ := matchWildcardMappingResult(mapping, requestedModel) + return mappedModel +} + +func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) { // 收集所有匹配的 pattern,按长度降序排序(最长优先) type patternMatch struct { pattern string @@ -621,7 +633,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri } if len(matches) == 0 { - return requestedModel // 无匹配,返回原始模型名 + return requestedModel, false // 无匹配,返回原始模型名 } // 按 pattern 长度降序排序 @@ -632,7 +644,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri return matches[i].pattern < matches[j].pattern }) - return matches[0].target + return matches[0].target, true } func (a *Account) IsCustomErrorCodesEnabled() bool { diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 7782f948..652735d3 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -268,6 +268,69 @@ func TestAccountGetMappedModel(t *testing.T) { } } +func TestAccountResolveMappedModel(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + requestedModel string + expectedModel string + expectedMatch bool + }{ + { + name: "no mapping reports unmatched", + credentials: nil, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + { + name: "exact passthrough mapping still counts as matched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "wildcard passthrough mapping still counts as matched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-*": "gpt-5.4", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: true, + }, + { + name: "missing mapping reports unmatched", + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.2": "gpt-5.2", + }, + }, + requestedModel: "gpt-5.4", + expectedModel: "gpt-5.4", + expectedMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Credentials: tt.credentials, + } + mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) + if mappedModel != tt.expectedModel || matched != tt.expectedMatch { + t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch) + } + }) + } +} + func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { account := &Account{ Platform: PlatformAntigravity, diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index f893eeb9..9529f6be 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } // 3. Model mapping - mappedModel := account.GetMappedModel(originalModel) - if mappedModel == originalModel && defaultMappedModel != "" { - mappedModel = defaultMappedModel - } + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) responsesReq.Model = mappedModel logger.L().Debug("openai chat_completions: model mapping applied", diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index e4a3d9c0..1e40ec6f 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. Model mapping - mappedModel := account.GetMappedModel(originalModel) - // 分组级降级:账号未映射时使用分组默认映射模型 - if mappedModel == originalModel && defaultMappedModel != "" { - mappedModel = defaultMappedModel - } + mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) responsesReq.Model = mappedModel logger.L().Debug("openai messages: model mapping applied", diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go new file mode 100644 index 00000000..9bf3fba3 --- /dev/null +++ b/backend/internal/service/openai_model_mapping.go @@ -0,0 +1,19 @@ +package service + +// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible +// forwarding. Group-level default mapping only applies when the account itself +// did not match any explicit model_mapping rule. +func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { + if account == nil { + if defaultMappedModel != "" { + return defaultMappedModel + } + return requestedModel + } + + mappedModel, matched := account.ResolveMappedModel(requestedModel) + if !matched && defaultMappedModel != "" { + return defaultMappedModel + } + return mappedModel +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go new file mode 100644 index 00000000..7af3ecae --- /dev/null +++ b/backend/internal/service/openai_model_mapping_test.go @@ -0,0 +1,70 @@ +package service + +import "testing" + +func TestResolveOpenAIForwardModel(t *testing.T) { + tests := []struct { + name string + account *Account + requestedModel string + defaultMappedModel string + expectedModel string + }{ + { + name: "falls back to group default when account has no mapping", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-4o-mini", + }, + { + name: "preserves exact passthrough mapping instead of group default", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + { + name: "preserves wildcard passthrough mapping instead of group default", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-*": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5.4", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + { + name: "uses account remap when explicit target differs", + account: &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-5": "gpt-5.4", + }, + }, + }, + requestedModel: "gpt-5", + defaultMappedModel: "gpt-4o-mini", + expectedModel: "gpt-5.4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel { + t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel) + } + }) + } +} From a377e990885f5bcecb98c4797bd988e810f71842 Mon Sep 17 00:00:00 2001 From: Wang Lvyuan <74089601+LvyuanW@users.noreply.github.com> Date: Sat, 14 Mar 2026 12:56:34 +0800 Subject: [PATCH 2/2] fix: remove unused wildcard mapping helper --- backend/internal/service/account.go | 7 ------- backend/internal/service/account_wildcard_test.go | 15 +++++++++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 7c858fd5..6c88ed68 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -612,13 +612,6 @@ func matchWildcard(pattern, str string) bool { return matchAntigravityWildcard(pattern, str) } -// matchWildcardMapping 通配符映射匹配(最长优先) -// 如果没有匹配,返回原始字符串 -func matchWildcardMapping(mapping map[string]string, requestedModel string) string { - mappedModel, _ := matchWildcardMappingResult(mapping, requestedModel) - return mappedModel -} - func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) { // 收集所有匹配的 pattern,按长度降序排序(最长优先) type patternMatch struct { diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 652735d3..0d7ffffa 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -43,12 +43,13 @@ func TestMatchWildcard(t *testing.T) { } } -func TestMatchWildcardMapping(t *testing.T) { +func TestMatchWildcardMappingResult(t *testing.T) { tests := []struct { name string mapping map[string]string requestedModel string expected string + matched bool }{ // 精确匹配优先于通配符 { @@ -59,6 +60,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5-exact", + matched: true, }, // 最长通配符优先 @@ -71,6 +73,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-series", + matched: true, }, // 单个通配符 @@ -81,6 +84,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "claude-opus-4-5", expected: "claude-mapped", + matched: true, }, // 无匹配返回原始模型 @@ -91,6 +95,7 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "gemini-3-flash", expected: "gemini-3-flash", + matched: false, }, // 空映射返回原始模型 @@ -99,6 +104,7 @@ func TestMatchWildcardMapping(t *testing.T) { mapping: map[string]string{}, requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", + matched: false, }, // Gemini 模型映射 @@ -110,14 +116,15 @@ func TestMatchWildcardMapping(t *testing.T) { }, requestedModel: "gemini-3-flash-preview", expected: "gemini-3-pro-high", + matched: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := matchWildcardMapping(tt.mapping, tt.requestedModel) - if result != tt.expected { - t.Errorf("matchWildcardMapping(%v, %q) = %q, want %q", tt.mapping, tt.requestedModel, result, tt.expected) + result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel) + if result != tt.expected || matched != tt.matched { + t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched) } }) }