mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 02:27:11 +00:00
fix: honor account model mapping before group fallback
This commit is contained in:
@@ -521,16 +521,23 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||||
// 如果未配置 mapping,返回原始模型名
|
// 如果未配置 mapping,返回原始模型名
|
||||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||||
|
mappedModel, _ := a.ResolveMappedModel(requestedModel)
|
||||||
|
return mappedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveMappedModel 获取映射后的模型名,并返回是否命中了账号级映射。
|
||||||
|
// matched=true 表示命中了精确映射或通配符映射,即使映射结果与原模型名相同。
|
||||||
|
func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, matched bool) {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return requestedModel
|
return requestedModel, false
|
||||||
}
|
}
|
||||||
// 精确匹配优先
|
// 精确匹配优先
|
||||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||||
return mappedModel
|
return mappedModel, true
|
||||||
}
|
}
|
||||||
// 通配符匹配(最长优先)
|
// 通配符匹配(最长优先)
|
||||||
return matchWildcardMapping(mapping, requestedModel)
|
return matchWildcardMappingResult(mapping, requestedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) GetBaseURL() string {
|
func (a *Account) GetBaseURL() string {
|
||||||
@@ -607,6 +614,11 @@ func matchWildcard(pattern, str string) bool {
|
|||||||
// matchWildcardMapping 通配符映射匹配(最长优先)
|
// matchWildcardMapping 通配符映射匹配(最长优先)
|
||||||
// 如果没有匹配,返回原始字符串
|
// 如果没有匹配,返回原始字符串
|
||||||
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
func matchWildcardMapping(mapping map[string]string, requestedModel string) string {
|
||||||
|
mappedModel, _ := matchWildcardMappingResult(mapping, requestedModel)
|
||||||
|
return mappedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func matchWildcardMappingResult(mapping map[string]string, requestedModel string) (string, bool) {
|
||||||
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
// 收集所有匹配的 pattern,按长度降序排序(最长优先)
|
||||||
type patternMatch struct {
|
type patternMatch struct {
|
||||||
pattern string
|
pattern string
|
||||||
@@ -621,7 +633,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(matches) == 0 {
|
if len(matches) == 0 {
|
||||||
return requestedModel // 无匹配,返回原始模型名
|
return requestedModel, false // 无匹配,返回原始模型名
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按 pattern 长度降序排序
|
// 按 pattern 长度降序排序
|
||||||
@@ -632,7 +644,7 @@ func matchWildcardMapping(mapping map[string]string, requestedModel string) stri
|
|||||||
return matches[i].pattern < matches[j].pattern
|
return matches[i].pattern < matches[j].pattern
|
||||||
})
|
})
|
||||||
|
|
||||||
return matches[0].target
|
return matches[0].target, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||||
|
|||||||
@@ -268,6 +268,69 @@ func TestAccountGetMappedModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountResolveMappedModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
credentials map[string]any
|
||||||
|
requestedModel string
|
||||||
|
expectedModel string
|
||||||
|
expectedMatch bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no mapping reports unmatched",
|
||||||
|
credentials: nil,
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact passthrough mapping still counts as matched",
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5.4": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard passthrough mapping still counts as matched",
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-*": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing mapping reports unmatched",
|
||||||
|
credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5.2": "gpt-5.2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
expectedMatch: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
account := &Account{
|
||||||
|
Credentials: tt.credentials,
|
||||||
|
}
|
||||||
|
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||||
|
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||||
|
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||||
account := &Account{
|
account := &Account{
|
||||||
Platform: PlatformAntigravity,
|
Platform: PlatformAntigravity,
|
||||||
|
|||||||
@@ -51,10 +51,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. Model mapping
|
// 3. Model mapping
|
||||||
mappedModel := account.GetMappedModel(originalModel)
|
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
|
||||||
mappedModel = defaultMappedModel
|
|
||||||
}
|
|
||||||
responsesReq.Model = mappedModel
|
responsesReq.Model = mappedModel
|
||||||
|
|
||||||
logger.L().Debug("openai chat_completions: model mapping applied",
|
logger.L().Debug("openai chat_completions: model mapping applied",
|
||||||
|
|||||||
@@ -59,11 +59,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. Model mapping
|
// 3. Model mapping
|
||||||
mappedModel := account.GetMappedModel(originalModel)
|
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
|
||||||
// 分组级降级:账号未映射时使用分组默认映射模型
|
|
||||||
if mappedModel == originalModel && defaultMappedModel != "" {
|
|
||||||
mappedModel = defaultMappedModel
|
|
||||||
}
|
|
||||||
responsesReq.Model = mappedModel
|
responsesReq.Model = mappedModel
|
||||||
|
|
||||||
logger.L().Debug("openai messages: model mapping applied",
|
logger.L().Debug("openai messages: model mapping applied",
|
||||||
|
|||||||
19
backend/internal/service/openai_model_mapping.go
Normal file
19
backend/internal/service/openai_model_mapping.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible
|
||||||
|
// forwarding. Group-level default mapping only applies when the account itself
|
||||||
|
// did not match any explicit model_mapping rule.
|
||||||
|
func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string {
|
||||||
|
if account == nil {
|
||||||
|
if defaultMappedModel != "" {
|
||||||
|
return defaultMappedModel
|
||||||
|
}
|
||||||
|
return requestedModel
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedModel, matched := account.ResolveMappedModel(requestedModel)
|
||||||
|
if !matched && defaultMappedModel != "" {
|
||||||
|
return defaultMappedModel
|
||||||
|
}
|
||||||
|
return mappedModel
|
||||||
|
}
|
||||||
70
backend/internal/service/openai_model_mapping_test.go
Normal file
70
backend/internal/service/openai_model_mapping_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestResolveOpenAIForwardModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account *Account
|
||||||
|
requestedModel string
|
||||||
|
defaultMappedModel string
|
||||||
|
expectedModel string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "falls back to group default when account has no mapping",
|
||||||
|
account: &Account{
|
||||||
|
Credentials: map[string]any{},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
defaultMappedModel: "gpt-4o-mini",
|
||||||
|
expectedModel: "gpt-4o-mini",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserves exact passthrough mapping instead of group default",
|
||||||
|
account: &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5.4": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
defaultMappedModel: "gpt-4o-mini",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserves wildcard passthrough mapping instead of group default",
|
||||||
|
account: &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-*": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5.4",
|
||||||
|
defaultMappedModel: "gpt-4o-mini",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uses account remap when explicit target differs",
|
||||||
|
account: &Account{
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gpt-5": "gpt-5.4",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
requestedModel: "gpt-5",
|
||||||
|
defaultMappedModel: "gpt-4o-mini",
|
||||||
|
expectedModel: "gpt-5.4",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := resolveOpenAIForwardModel(tt.account, tt.requestedModel, tt.defaultMappedModel); got != tt.expectedModel {
|
||||||
|
t.Fatalf("resolveOpenAIForwardModel(...) = %q, want %q", got, tt.expectedModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user