diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 1f3952047..3b7b5d83e 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -1,7 +1,9 @@ package aws import ( + "encoding/json" "io" + "net/http" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" @@ -10,6 +12,7 @@ import ( type AwsClaudeRequest struct { // AnthropicVersion should be "bedrock-2023-05-31" AnthropicVersion string `json:"anthropic_version"` + AnthropicBeta json.RawMessage `json:"anthropic_beta"` System any `json:"system,omitempty"` Messages []dto.ClaudeMessage `json:"messages"` MaxTokens uint `json:"max_tokens,omitempty"` @@ -22,29 +25,23 @@ type AwsClaudeRequest struct { Thinking *dto.Thinking `json:"thinking,omitempty"` } -func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest { - return &AwsClaudeRequest{ - AnthropicVersion: "bedrock-2023-05-31", - System: req.System, - Messages: req.Messages, - MaxTokens: req.MaxTokens, - Temperature: req.Temperature, - TopP: req.TopP, - TopK: req.TopK, - StopSequences: req.StopSequences, - Tools: req.Tools, - ToolChoice: req.ToolChoice, - Thinking: req.Thinking, - } -} - -func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) { +func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) { var awsClaudeRequest AwsClaudeRequest err := common.DecodeJson(requestBody, &awsClaudeRequest) if err != nil { return nil, err } awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31" + + // check header anthropic-beta + anthropicBetaValues := requestHeader.Values("anthropic-beta") + if len(anthropicBetaValues) > 0 { + betaJson, err := json.Marshal(anthropicBetaValues) + if err != nil { + return nil, err + } + awsClaudeRequest.AnthropicBeta = json.RawMessage(betaJson) + } return &awsClaudeRequest, nil } diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index ec5ea8988..0795498ce 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -73,7 +73,6 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, } a.AwsClient = awsCli - println(info.UpstreamModelName) // 获取对应的AWS模型ID awsModelId := getAwsModelID(info.UpstreamModelName) @@ -83,6 +82,10 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) } + // init empty request.header + requestHeader := http.Header{} + a.SetupRequestHeader(c, &requestHeader, info) + if isNovaModel(awsModelId) { var novaReq *NovaRequest err = common.DecodeJson(requestBody, &novaReq) @@ -104,7 +107,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, awsReq.Body = reqBody return nil, nil } else { - awsClaudeReq, err := formatRequest(requestBody) + awsClaudeReq, err := formatRequest(requestBody, requestHeader) if err != nil { return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody) }