diff --git a/dto/claude.go b/dto/claude.go index e9f42a1b3..e7d87c41f 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -434,7 +434,7 @@ func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) { } type Thinking struct { - Type string `json:"type"` + Type string `json:"type,omitempty"` BudgetTokens *int `json:"budget_tokens,omitempty"` } diff --git a/go.mod b/go.mod index cf5fbfd34..2f28f7817 100644 --- a/go.mod +++ b/go.mod @@ -8,10 +8,10 @@ require ( github.com/abema/go-mp4 v1.4.1 github.com/andybalholm/brotli v1.1.1 github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 - github.com/aws/aws-sdk-go-v2 v1.37.2 - github.com/aws/aws-sdk-go-v2/credentials v1.17.11 - github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 - github.com/aws/smithy-go v1.22.5 + github.com/aws/aws-sdk-go-v2 v1.41.2 + github.com/aws/aws-sdk-go-v2/credentials v1.19.10 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0 + github.com/aws/smithy-go v1.24.2 github.com/bytedance/gopkg v0.1.3 github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/gzip v0.0.6 @@ -62,9 +62,9 @@ require ( require ( github.com/DmitriyVTitov/size v1.5.0 // indirect github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.1.0 // indirect github.com/bytedance/sonic v1.14.1 // indirect diff --git a/go.sum b/go.sum index 23fe79489..742989293 100644 --- a/go.sum +++ b/go.sum @@ -12,18 +12,34 @@ github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63q github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo= github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= +github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g= github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0 h1:TDKR8ACRw7G+GFaQlhoy6biu+8q6ZtSddQCy9avMdMI= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0/go.mod h1:XlhOh5Ax/lesqN4aZCUgj9vVJed5VoXYHHFYGAlJEwU= github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw= github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= +github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= +github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 79eac3ad6..c00dc1a69 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -267,6 +267,10 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s return headerOverride, nil } +func ResolveHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) { + return processHeaderOverride(info, c) +} + func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) { if req == nil { return diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 042f091ef..4c5c5cbc8 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -27,6 +27,7 @@ type AwsClaudeRequest struct { ToolChoice any `json:"tool_choice,omitempty"` Thinking *dto.Thinking `json:"thinking,omitempty"` OutputConfig json.RawMessage `json:"output_config,omitempty"` + //Metadata json.RawMessage `json:"metadata,omitempty"` } func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) { diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index c2a676738..1f6ff7e69 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -11,6 +11,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/relay/channel" "github.com/QuantumNous/new-api/relay/channel/claude" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/relay/helper" @@ -106,6 +107,13 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, // init empty request.header requestHeader := http.Header{} a.SetupRequestHeader(c, &requestHeader, info) + headerOverride, err := channel.ResolveHeaderOverride(info, c) + if err != nil { + return nil, err + } + for key, value := range headerOverride { + requestHeader.Set(key, value) + } if isNovaModel(awsModelId) { var novaReq *NovaRequest diff --git a/relay/channel/aws/relay_aws_test.go b/relay/channel/aws/relay_aws_test.go new file mode 100644 index 000000000..92745ff40 --- /dev/null +++ b/relay/channel/aws/relay_aws_test.go @@ -0,0 +1,55 @@ +package aws + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/QuantumNous/new-api/common" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestDoAwsClientRequest_AppliesRuntimeHeaderOverrideToAnthropicBeta(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + info := &relaycommon.RelayInfo{ + OriginModelName: "claude-3-5-sonnet-20240620", + IsStream: false, + UseRuntimeHeadersOverride: true, + RuntimeHeadersOverride: map[string]any{ + "anthropic-beta": "computer-use-2025-01-24", + }, + ChannelMeta: &relaycommon.ChannelMeta{ + ApiKey: "access-key|secret-key|us-east-1", + UpstreamModelName: "claude-3-5-sonnet-20240620", + }, + } + + requestBody := bytes.NewBufferString(`{"messages":[{"role":"user","content":"hello"}],"max_tokens":128}`) + adaptor := &Adaptor{} + + _, err := doAwsClientRequest(ctx, info, adaptor, requestBody) + require.NoError(t, err) + + awsReq, ok := adaptor.AwsReq.(*bedrockruntime.InvokeModelInput) + require.True(t, ok) + + var payload map[string]any + require.NoError(t, common.Unmarshal(awsReq.Body, &payload)) + + anthropicBeta, exists := payload["anthropic_beta"] + require.True(t, exists) + + values, ok := anthropicBeta.([]any) + require.True(t, ok) + require.Equal(t, []any{"computer-use-2025-01-24"}, values) +} diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go index 86b628e08..c1d13a6dd 100644 --- a/relay/channel/vertex/dto.go +++ b/relay/channel/vertex/dto.go @@ -20,6 +20,7 @@ type VertexAIClaudeRequest struct { ToolChoice any `json:"tool_choice,omitempty"` Thinking *dto.Thinking `json:"thinking,omitempty"` OutputConfig json.RawMessage `json:"output_config,omitempty"` + //Metadata json.RawMessage `json:"metadata,omitempty"` } func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest { diff --git a/relay/common/override.go b/relay/common/override.go index 59e15176b..e0761ab63 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -120,8 +120,18 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c // 尝试断言为操作格式 if operations, ok := tryParseOperations(paramOverride); ok { + legacyOverride := buildLegacyParamOverride(paramOverride) + workingJSON := jsonData + var err error + if len(legacyOverride) > 0 { + workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride) + if err != nil { + return nil, err + } + } + // 使用新方法 - result, err := applyOperations(string(jsonData), operations, conditionContext) + result, err := applyOperations(string(workingJSON), operations, conditionContext) return []byte(result), err } @@ -129,6 +139,20 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c return applyOperationsLegacy(jsonData, paramOverride) } +func buildLegacyParamOverride(paramOverride map[string]interface{}) map[string]interface{} { + if len(paramOverride) == 0 { + return nil + } + legacy := make(map[string]interface{}, len(paramOverride)) + for key, value := range paramOverride { + if strings.EqualFold(strings.TrimSpace(key), "operations") { + continue + } + legacy[key] = value + } + return legacy +} + func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) { paramOverride := getParamOverrideMap(info) if len(paramOverride) == 0 { diff --git a/relay/common/override_test.go b/relay/common/override_test.go index 5f49d95ae..dbdcd4096 100644 --- a/relay/common/override_test.go +++ b/relay/common/override_test.go @@ -74,6 +74,48 @@ func TestApplyParamOverrideTrimNoop(t *testing.T) { assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out)) } +func TestApplyParamOverrideMixedLegacyAndOperations(t *testing.T) { + input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`) + override := map[string]interface{}{ + "temperature": 0.2, + "top_p": 0.95, + "operations": []interface{}{ + map[string]interface{}{ + "path": "model", + "mode": "trim_prefix", + "value": "openai/", + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-4","temperature":0.2,"top_p":0.95}`, string(out)) +} + +func TestApplyParamOverrideMixedLegacyAndOperationsConflictPrefersOperations(t *testing.T) { + input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`) + override := map[string]interface{}{ + "model": "legacy-model", + "temperature": 0.2, + "operations": []interface{}{ + map[string]interface{}{ + "path": "model", + "mode": "set", + "value": "op-model", + }, + }, + } + + out, err := ApplyParamOverride(input, override, nil) + if err != nil { + t.Fatalf("ApplyParamOverride returned error: %v", err) + } + assertJSONEqual(t, `{"model":"op-model","temperature":0.2}`, string(out)) +} + func TestApplyParamOverrideTrimRequiresValue(t *testing.T) { // trim_prefix requires value example: // {"operations":[{"path":"model","mode":"trim_prefix"}]} @@ -1429,6 +1471,44 @@ func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) { } } +func TestApplyParamOverrideWithRelayInfoMixedLegacyAndOperations(t *testing.T) { + info := &RelayInfo{ + RequestHeaders: map[string]string{ + "Originator": "Codex CLI", + }, + ChannelMeta: &ChannelMeta{ + ParamOverride: map[string]interface{}{ + "temperature": 0.2, + "operations": []interface{}{ + map[string]interface{}{ + "mode": "pass_headers", + "value": []interface{}{"Originator"}, + }, + }, + }, + HeadersOverride: map[string]interface{}{ + "X-Static": "legacy-static", + }, + }, + } + + out, err := ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5","temperature":0.7}`), info) + if err != nil { + t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err) + } + assertJSONEqual(t, `{"model":"gpt-5","temperature":0.2}`, string(out)) + + if !info.UseRuntimeHeadersOverride { + t.Fatalf("expected runtime header override to be enabled") + } + if info.RuntimeHeadersOverride["x-static"] != "legacy-static" { + t.Fatalf("expected x-static to be preserved, got: %v", info.RuntimeHeadersOverride["x-static"]) + } + if info.RuntimeHeadersOverride["originator"] != "Codex CLI" { + t.Fatalf("expected originator header to be passed, got: %v", info.RuntimeHeadersOverride["originator"]) + } +} + func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) { info := &RelayInfo{ ChannelMeta: &ChannelMeta{ diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 463837865..c5477ccea 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -229,7 +229,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) { textRequest = &dto.ClaudeRequest{} - err = c.ShouldBindJSON(textRequest) + err = common.UnmarshalBodyReusable(c, textRequest) if err != nil { return nil, err } diff --git a/service/channel_affinity.go b/service/channel_affinity.go index 3e90b9c22..c8177f9d8 100644 --- a/service/channel_affinity.go +++ b/service/channel_affinity.go @@ -436,11 +436,46 @@ func mergeChannelOverride(base map[string]interface{}, tpl map[string]interface{ } out := cloneStringAnyMap(base) for k, v := range tpl { + if strings.EqualFold(strings.TrimSpace(k), "operations") { + baseOps, hasBaseOps := extractParamOperations(out[k]) + tplOps, hasTplOps := extractParamOperations(v) + if hasTplOps { + if hasBaseOps { + out[k] = append(tplOps, baseOps...) + } else { + out[k] = tplOps + } + continue + } + } + if _, exists := out[k]; exists { + continue + } out[k] = v } return out } +func extractParamOperations(value interface{}) ([]interface{}, bool) { + switch ops := value.(type) { + case []interface{}: + if len(ops) == 0 { + return []interface{}{}, true + } + cloned := make([]interface{}, 0, len(ops)) + cloned = append(cloned, ops...) + return cloned, true + case []map[string]interface{}: + cloned := make([]interface{}, 0, len(ops)) + for _, op := range ops { + cloned = append(cloned, op) + } + return cloned, true + default: + return nil, false + } +} + func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinityMeta) { if c == nil { return diff --git a/service/channel_affinity_template_test.go b/service/channel_affinity_template_test.go index acf301543..4a024e99b 100644 --- a/service/channel_affinity_template_test.go +++ b/service/channel_affinity_template_test.go @@ -56,7 +56,7 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) { merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base) require.True(t, applied) - require.Equal(t, 0.2, merged["temperature"]) + require.Equal(t, 0.7, merged["temperature"]) require.Equal(t, 0.95, merged["top_p"]) require.Equal(t, 2000, merged["max_tokens"]) require.Equal(t, 0.7, base["temperature"]) @@ -74,6 +74,48 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) { require.EqualValues(t, 2, overrideInfo["param_override_keys"]) } +func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) { + ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{ + RuleName: "rule-with-ops-template", + ParamTemplate: map[string]interface{}{ + "operations": []map[string]interface{}{ + { + "mode": "pass_headers", + "value": []string{"Originator"}, + }, + }, + }, + }) + base := map[string]interface{}{ + "temperature": 0.7, + "operations": []map[string]interface{}{ + { + "path": "model", + "mode": "trim_prefix", + "value": "openai/", + }, + }, + } + + merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base) + require.True(t, applied) + require.Equal(t, 0.7, merged["temperature"]) + + opsAny, ok := merged["operations"] + require.True(t, ok) + ops, ok := opsAny.([]interface{}) + require.True(t, ok) + require.Len(t, ops, 2) + + firstOp, ok := ops[0].(map[string]interface{}) + require.True(t, ok) + require.Equal(t, "pass_headers", firstOp["mode"]) + + secondOp, ok := ops[1].(map[string]interface{}) + require.True(t, ok) + require.Equal(t, "trim_prefix", secondOp["mode"]) +} + func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 3a91207dc..ea0debd8b 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -759,6 +759,10 @@ const EditChannelModal = (props) => { } }; + const clearParamOverride = () => { + handleInputChange('param_override', ''); + }; + const loadChannel = async () => { setLoading(true); let res = await API.get(`/api/channel/${channelId}`); @@ -3356,6 +3360,13 @@ const EditChannelModal = (props) => { > {t('填充旧模板')} +