Merge pull request #3066 from seefs001/fix/aws-header-override

Fix/aws header override
This commit is contained in:
Seefs
2026-03-02 18:54:56 +08:00
committed by GitHub
parent f2c5acf815
commit 0689600103
14 changed files with 288 additions and 11 deletions

View File

@@ -267,6 +267,10 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
return headerOverride, nil
}
func ResolveHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
return processHeaderOverride(info, c)
}
func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) {
if req == nil {
return

View File

@@ -27,6 +27,7 @@ type AwsClaudeRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
//Metadata json.RawMessage `json:"metadata,omitempty"`
}
func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
@@ -106,6 +107,13 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
// init empty request.header
requestHeader := http.Header{}
a.SetupRequestHeader(c, &requestHeader, info)
headerOverride, err := channel.ResolveHeaderOverride(info, c)
if err != nil {
return nil, err
}
for key, value := range headerOverride {
requestHeader.Set(key, value)
}
if isNovaModel(awsModelId) {
var novaReq *NovaRequest

View File

@@ -0,0 +1,55 @@
package aws
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/QuantumNous/new-api/common"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestDoAwsClientRequest_AppliesRuntimeHeaderOverrideToAnthropicBeta(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
info := &relaycommon.RelayInfo{
OriginModelName: "claude-3-5-sonnet-20240620",
IsStream: false,
UseRuntimeHeadersOverride: true,
RuntimeHeadersOverride: map[string]any{
"anthropic-beta": "computer-use-2025-01-24",
},
ChannelMeta: &relaycommon.ChannelMeta{
ApiKey: "access-key|secret-key|us-east-1",
UpstreamModelName: "claude-3-5-sonnet-20240620",
},
}
requestBody := bytes.NewBufferString(`{"messages":[{"role":"user","content":"hello"}],"max_tokens":128}`)
adaptor := &Adaptor{}
_, err := doAwsClientRequest(ctx, info, adaptor, requestBody)
require.NoError(t, err)
awsReq, ok := adaptor.AwsReq.(*bedrockruntime.InvokeModelInput)
require.True(t, ok)
var payload map[string]any
require.NoError(t, common.Unmarshal(awsReq.Body, &payload))
anthropicBeta, exists := payload["anthropic_beta"]
require.True(t, exists)
values, ok := anthropicBeta.([]any)
require.True(t, ok)
require.Equal(t, []any{"computer-use-2025-01-24"}, values)
}

View File

@@ -20,6 +20,7 @@ type VertexAIClaudeRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
//Metadata json.RawMessage `json:"metadata,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {

View File

@@ -120,8 +120,18 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
// 尝试断言为操作格式
if operations, ok := tryParseOperations(paramOverride); ok {
legacyOverride := buildLegacyParamOverride(paramOverride)
workingJSON := jsonData
var err error
if len(legacyOverride) > 0 {
workingJSON, err = applyOperationsLegacy(workingJSON, legacyOverride)
if err != nil {
return nil, err
}
}
// 使用新方法
result, err := applyOperations(string(jsonData), operations, conditionContext)
result, err := applyOperations(string(workingJSON), operations, conditionContext)
return []byte(result), err
}
@@ -129,6 +139,20 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, c
return applyOperationsLegacy(jsonData, paramOverride)
}
func buildLegacyParamOverride(paramOverride map[string]interface{}) map[string]interface{} {
if len(paramOverride) == 0 {
return nil
}
legacy := make(map[string]interface{}, len(paramOverride))
for key, value := range paramOverride {
if strings.EqualFold(strings.TrimSpace(key), "operations") {
continue
}
legacy[key] = value
}
return legacy
}
func ApplyParamOverrideWithRelayInfo(jsonData []byte, info *RelayInfo) ([]byte, error) {
paramOverride := getParamOverrideMap(info)
if len(paramOverride) == 0 {

View File

@@ -74,6 +74,48 @@ func TestApplyParamOverrideTrimNoop(t *testing.T) {
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideMixedLegacyAndOperations(t *testing.T) {
input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"temperature": 0.2,
"top_p": 0.95,
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_prefix",
"value": "openai/",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.2,"top_p":0.95}`, string(out))
}
func TestApplyParamOverrideMixedLegacyAndOperationsConflictPrefersOperations(t *testing.T) {
input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"model": "legacy-model",
"temperature": 0.2,
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "set",
"value": "op-model",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"op-model","temperature":0.2}`, string(out))
}
func TestApplyParamOverrideTrimRequiresValue(t *testing.T) {
// trim_prefix requires value example:
// {"operations":[{"path":"model","mode":"trim_prefix"}]}
@@ -1429,6 +1471,44 @@ func TestApplyParamOverrideWithRelayInfoSyncRuntimeHeaders(t *testing.T) {
}
}
func TestApplyParamOverrideWithRelayInfoMixedLegacyAndOperations(t *testing.T) {
info := &RelayInfo{
RequestHeaders: map[string]string{
"Originator": "Codex CLI",
},
ChannelMeta: &ChannelMeta{
ParamOverride: map[string]interface{}{
"temperature": 0.2,
"operations": []interface{}{
map[string]interface{}{
"mode": "pass_headers",
"value": []interface{}{"Originator"},
},
},
},
HeadersOverride: map[string]interface{}{
"X-Static": "legacy-static",
},
},
}
out, err := ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-5","temperature":0.7}`), info)
if err != nil {
t.Fatalf("ApplyParamOverrideWithRelayInfo returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-5","temperature":0.2}`, string(out))
if !info.UseRuntimeHeadersOverride {
t.Fatalf("expected runtime header override to be enabled")
}
if info.RuntimeHeadersOverride["x-static"] != "legacy-static" {
t.Fatalf("expected x-static to be preserved, got: %v", info.RuntimeHeadersOverride["x-static"])
}
if info.RuntimeHeadersOverride["originator"] != "Codex CLI" {
t.Fatalf("expected originator header to be passed, got: %v", info.RuntimeHeadersOverride["originator"])
}
}
func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
info := &RelayInfo{
ChannelMeta: &ChannelMeta{

View File

@@ -229,7 +229,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
textRequest = &dto.ClaudeRequest{}
err = c.ShouldBindJSON(textRequest)
err = common.UnmarshalBodyReusable(c, textRequest)
if err != nil {
return nil, err
}