Merge pull request #3066 from seefs001/fix/aws-header-override

Fix/aws header override
This commit is contained in:
Seefs
2026-03-02 18:54:56 +08:00
committed by GitHub
parent f2c5acf815
commit 0689600103
14 changed files with 288 additions and 11 deletions

View File

@@ -434,7 +434,7 @@ func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
}
type Thinking struct {
Type string `json:"type"`
Type string `json:"type,omitempty"`
BudgetTokens *int `json:"budget_tokens,omitempty"`
}

14
go.mod
View File

@@ -8,10 +8,10 @@ require (
github.com/abema/go-mp4 v1.4.1
github.com/andybalholm/brotli v1.1.1
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
github.com/aws/aws-sdk-go-v2 v1.37.2
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
github.com/aws/smithy-go v1.22.5
github.com/aws/aws-sdk-go-v2 v1.41.2
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0
github.com/aws/smithy-go v1.24.2
github.com/bytedance/gopkg v0.1.3
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
@@ -62,9 +62,9 @@ require (
require (
github.com/DmitriyVTitov/size v1.5.0 // indirect
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.1.0 // indirect
github.com/bytedance/sonic v1.14.1 // indirect

16
go.sum
View File

@@ -12,18 +12,34 @@ github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63q
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0 h1:TDKR8ACRw7G+GFaQlhoy6biu+8q6ZtSddQCy9avMdMI=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.50.0/go.mod h1:XlhOh5Ax/lesqN4aZCUgj9vVJed5VoXYHHFYGAlJEwU=
github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=

View File

@@ -267,6 +267,10 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
return headerOverride, nil
}
func ResolveHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
return processHeaderOverride(info, c)
}
func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) {
if req == nil {
return

View File

@@ -27,6 +27,7 @@ type AwsClaudeRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
//Metadata json.RawMessage `json:"metadata,omitempty"`
}
func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
@@ -106,6 +107,13 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
// init empty request.header
requestHeader := http.Header{}
a.SetupRequestHeader(c, &requestHeader, info)
headerOverride, err := channel.ResolveHeaderOverride(info, c)
if err != nil {
return nil, err
}
for key, value := range headerOverride {
requestHeader.Set(key, value)
}
if isNovaModel(awsModelId) {
var novaReq *NovaRequest

View File

@@ -0,0 +1,55 @@
package aws
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestDoAwsClientRequest_AppliesRuntimeHeaderOverrideToAnthropicBeta(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
info := &relaycommon.RelayInfo{
OriginModelName: "claude-3-5-sonnet-20240620",
IsStream: false,
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]any{
"anthropic-beta": "computer-use-2025-01-24",
},
ChannelMeta: &relaycommon.ChannelMeta{
ApiKey: "access-key|secret-key|us-east-1",
UpstreamModelName: "claude-3-5-sonnet-20240620",
},
}
requestBody := bytes.NewBufferString(`{"messages":[{"role":"user","content":"hello"}],"max_tokens":128}`)
adaptor := &Adaptor{}
_, err := doAwsClientRequest(ctx, info, adaptor, requestBody)
require.NoError(t, err)
awsReq, ok := adaptor.AwsReq.(*bedrockruntime.InvokeModelInput)
require.True(t, ok)
var payload map[string]any
require.NoError(t, common.Unmarshal(awsReq.Body, &payload))
anthropicBeta, exists := payload["anthropic_beta"]
require.True(t, exists)
values, ok := anthropicBeta.([]any)
require.True(t, ok)
require.Equal(t, []any{"computer-use-2025-01-24"}, values)
}

View File

@@ -20,6 +20,7 @@ type VertexAIClaudeRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
//Metadata json.RawMessage `json:"metadata,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {

View File

@@ -120,8 +120,18 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
// 尝试断言为操作格式
if operations, ok := tryParseOperations(paramOverride); ok {
legacyOverride := buildLegacyParamOverride(paramOverride)
workingJSON := jsonData
var err error
if len(legacyOverride) > 0 {
workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride)
if err != nil {
return nil, err
}
}
// 使用新方法
result, err := applyOperations(string(jsonData), operations, conditionContext)
result, err := applyOperations(string(workingJSON), operations, conditionContext)
return []byte(result), err
}
@@ -129,6 +139,20 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
return applyOperationsLegacy(jsonData, paramOverride)
}
func buildLegacyParamOverride(paramOverride map[string]interface{}) map[string]interface{} {
if len(paramOverride) == 0 {
return nil
}
legacy := make(map[string]interface{}, len(paramOverride))
for key, value := range paramOverride {
if strings.EqualFold(strings.TrimSpace(key), "operations") {
continue
}
legacy[key] = value
}
return legacy
}
func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) {
paramOverride := getParamOverrideMap(info)
if len(paramOverride) == 0 {

View File

@@ -74,6 +74,48 @@ func TestApplyParamOverrideTrimNoop(t *testing.T) {
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideMixedLegacyAndOperations(t *testing.T) {
input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"temperature": 0.2,
"top_p": 0.95,
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_prefix",
"value": "openai/",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.2,"top_p":0.95}`, string(out))
}
func TestApplyParamOverrideMixedLegacyAndOperationsConflictPrefersOperations(t *testing.T) {
input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"model": "legacy-model",
"temperature": 0.2,
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "set",
"value": "op-model",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"op-model","temperature":0.2}`, string(out))
}
func TestApplyParamOverrideTrimRequiresValue(t *testing.T) {
// trim_prefix requires value example:
// {"operations":[{"path":"model","mode":"trim_prefix"}]}
@@ -1429,6 +1471,44 @@ func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
}
}
func TestApplyParamOverrideWithRelayInfoMixedLegacyAndOperations(t *testing.T) {
info := &RelayInfo{
RequestHeaders: map[string]string{
"Originator": "Codex CLI",
},
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"temperature": 0.2,
"operations": []interface{}{
map[string]interface{}{
"mode": "pass_headers",
"value": []interface{}{"Originator"},
},
},
},
HeadersOverride: map[string]interface{}{
"X-Static": "legacy-static",
},
},
}
out, err := ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5","temperature":0.7}`), info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-5","temperature":0.2}`, string(out))
if !info.UseRuntimeHeadersOverride {
t.Fatalf("expected runtime header override to be enabled")
}
if info.RuntimeHeadersOverride["x-static"] != "legacy-static" {
t.Fatalf("expected x-static to be preserved, got: %v", info.RuntimeHeadersOverride["x-static"])
}
if info.RuntimeHeadersOverride["originator"] != "Codex CLI" {
t.Fatalf("expected originator header to be passed, got: %v", info.RuntimeHeadersOverride["originator"])
}
}
func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{

View File

@@ -229,7 +229,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
textRequest = &dto.ClaudeRequest{}
err = c.ShouldBindJSON(textRequest)
err = common.UnmarshalBodyReusable(c, textRequest)
if err != nil {
return nil, err
}

View File

@@ -436,11 +436,46 @@ func mergeChannelOverride(base map[string]interface{}, tpl map[string]interface{
}
out := cloneStringAnyMap(base)
for k, v := range tpl {
if strings.EqualFold(strings.TrimSpace(k), "operations") {
baseOps, hasBaseOps := extractParamOperations(out[k])
tplOps, hasTplOps := extractParamOperations(v)
if hasTplOps {
if hasBaseOps {
out[k] = append(tplOps, baseOps...)
} else {
out[k] = tplOps
}
continue
}
}
if _, exists := out[k]; exists {
continue
}
out[k] = v
}
return out
}
func extractParamOperations(value interface{}) ([]interface{}, bool) {
switch ops := value.(type) {
case []interface{}:
if len(ops) == 0 {
return []interface{}{}, true
}
cloned := make([]interface{}, 0, len(ops))
cloned = append(cloned, ops...)
return cloned, true
case []map[string]interface{}:
cloned := make([]interface{}, 0, len(ops))
for _, op := range ops {
cloned = append(cloned, op)
}
return cloned, true
default:
return nil, false
}
}
func appendChannelAffinityTemplateAdminInfo(c *gin.Context, meta channelAffinityMeta) {
if c == nil {
return

View File

@@ -56,7 +56,7 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) {
merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
require.True(t, applied)
require.Equal(t, 0.2, merged["temperature"])
require.Equal(t, 0.7, merged["temperature"])
require.Equal(t, 0.95, merged["top_p"])
require.Equal(t, 2000, merged["max_tokens"])
require.Equal(t, 0.7, base["temperature"])
@@ -74,6 +74,48 @@ func TestApplyChannelAffinityOverrideTemplate_MergeTemplate(t *testing.T) {
require.EqualValues(t, 2, overrideInfo["param_override_keys"])
}
func TestApplyChannelAffinityOverrideTemplate_MergeOperations(t *testing.T) {
ctx := buildChannelAffinityTemplateContextForTest(channelAffinityMeta{
RuleName: "rule-with-ops-template",
ParamTemplate: map[string]interface{}{
"operations": []map[string]interface{}{
{
"mode": "pass_headers",
"value": []string{"Originator"},
},
},
},
})
base := map[string]interface{}{
"temperature": 0.7,
"operations": []map[string]interface{}{
{
"path": "model",
"mode": "trim_prefix",
"value": "openai/",
},
},
}
merged, applied := ApplyChannelAffinityOverrideTemplate(ctx, base)
require.True(t, applied)
require.Equal(t, 0.7, merged["temperature"])
opsAny, ok := merged["operations"]
require.True(t, ok)
ops, ok := opsAny.([]interface{})
require.True(t, ok)
require.Len(t, ops, 2)
firstOp, ok := ops[0].(map[string]interface{})
require.True(t, ok)
require.Equal(t, "pass_headers", firstOp["mode"])
secondOp, ok := ops[1].(map[string]interface{})
require.True(t, ok)
require.Equal(t, "trim_prefix", secondOp["mode"])
}
func TestChannelAffinityHitCodexTemplatePassHeadersEffective(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -759,6 +759,10 @@ const EditChannelModal = (props) => {
}
};
const clearParamOverride = () => {
handleInputChange('param_override', '');
};
const loadChannel = async () => {
setLoading(true);
let res = await API.get(`/api/channel/${channelId}`);
@@ -3356,6 +3360,13 @@ const EditChannelModal = (props) => {
>
{t('填充旧模板')}
</Button>
<Button
size='small'
type='tertiary'
onClick={clearParamOverride}
>
{t('清空')}
</Button>
</Space>
</div>
<Text type='tertiary' size='small'>