feat(aws): Add support for anthropic-beta header in AwsClaudeRequest

This commit is contained in:
CaIon
2025-11-14 12:01:00 +08:00
parent 974df5e7b9
commit e1a52f1d5a
2 changed files with 19 additions and 19 deletions

View File

@@ -1,7 +1,9 @@
package aws
import (
"encoding/json"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
@@ -10,6 +12,7 @@ import (
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
AnthropicBeta json.RawMessage `json:"anthropic_beta"`
System any `json:"system,omitempty"`
Messages []dto.ClaudeMessage `json:"messages"`
MaxTokens uint `json:"max_tokens,omitempty"`
@@ -22,29 +25,23 @@ type AwsClaudeRequest struct {
Thinking *dto.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
return &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
System: req.System,
Messages: req.Messages,
MaxTokens: req.MaxTokens,
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}
func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {
var awsClaudeRequest AwsClaudeRequest
err := common.DecodeJson(requestBody, &awsClaudeRequest)
if err != nil {
return nil, err
}
awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
// check header anthropic-beta
anthropicBetaValues := requestHeader.Values("anthropic-beta")
if len(anthropicBetaValues) > 0 {
betaJson, err := json.Marshal(anthropicBetaValues)
if err != nil {
return nil, err
}
awsClaudeRequest.AnthropicBeta = json.RawMessage(betaJson)
}
return &awsClaudeRequest, nil
}

View File

@@ -73,7 +73,6 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
}
a.AwsClient = awsCli
println(info.UpstreamModelName)
// 获取对应的AWS模型ID
awsModelId := getAwsModelID(info.UpstreamModelName)
@@ -83,6 +82,10 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
// init empty request.header
requestHeader := http.Header{}
a.SetupRequestHeader(c, &requestHeader, info)
if isNovaModel(awsModelId) {
var novaReq *NovaRequest
err = common.DecodeJson(requestBody, &novaReq)
@@ -104,7 +107,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
awsReq.Body = reqBody
return nil, nil
} else {
awsClaudeReq, err := formatRequest(requestBody)
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
}