From 8b65623726aebb8d88d59aeef01fe5bf24bd7c23 Mon Sep 17 00:00:00 2001 From: CaIon Date: Wed, 15 Oct 2025 16:44:33 +0800 Subject: [PATCH] refactor: aws --- common/json.go | 3 +- model/log.go | 1 + relay/channel/aws/adaptor.go | 97 ++++++++++++++++++++++++++----- relay/channel/aws/constants.go | 2 +- relay/channel/aws/dto.go | 13 +++++ relay/channel/aws/relay-aws.go | 102 ++------------------------------- types/error.go | 3 + 7 files changed, 108 insertions(+), 113 deletions(-) diff --git a/common/json.go b/common/json.go index 13e23a460..a65da462e 100644 --- a/common/json.go +++ b/common/json.go @@ -3,6 +3,7 @@ package common import ( "bytes" "encoding/json" + "io" ) func Unmarshal(data []byte, v any) error { @@ -13,7 +14,7 @@ func UnmarshalJsonStr(data string, v any) error { return json.Unmarshal(StringToByteSlice(data), v) } -func DecodeJson(reader *bytes.Reader, v any) error { +func DecodeJson(reader io.Reader, v any) error { return json.NewDecoder(reader).Decode(v) } diff --git a/model/log.go b/model/log.go index 3df7c5eb8..97bc47754 100644 --- a/model/log.go +++ b/model/log.go @@ -45,6 +45,7 @@ const ( LogTypeConsume LogTypeManage LogTypeSystem + LogTypeRefund LogTypeError ) diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 437da8d0c..5d3f9ac71 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -1,14 +1,17 @@ package aws import ( - "errors" "io" "net/http" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel/claude" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/pkg/errors" "github.com/gin-gonic/gin" ) @@ -19,7 +22,10 @@ const ( ) type Adaptor struct { - RequestMode int + AwsClient *bedrockruntime.Client + AwsModelId string + AwsReq any + IsNova bool } func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { @@ -28,8 +34,6 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { - c.Set("request_model", request.Model) - c.Set("converted_request", request) return request, nil } @@ -44,7 +48,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { - a.RequestMode = RequestModeMessage } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { @@ -63,9 +66,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn // 检查是否为Nova模型 if isNovaModel(request.Model) { novaReq := convertToNovaRequest(request) - c.Set("request_model", request.Model) - c.Set("converted_request", novaReq) - c.Set("is_nova_model", true) + a.IsNova = true return novaReq, nil } @@ -76,9 +77,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if err != nil { return nil, err } - c.Set("request_model", claudeReq.Model) - c.Set("converted_request", claudeReq) - c.Set("is_nova_model", false) return claudeReq, err } @@ -97,14 +95,83 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - return nil, nil + awsCli, err := newAwsClient(c, info) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError) + } + a.AwsClient = awsCli + + awsModelId := awsModelID(info.UpstreamModelName) + + awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) + canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) + if canCrossRegion { + awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) + } + + if isNovaModel(awsModelId) { + var novaReq *NovaRequest + err = common.DecodeJson(requestBody, &novaReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody) + } + + // 使用InvokeModel API,但使用Nova格式的请求体 + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + reqBody, err := common.Marshal(novaReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody) + } + awsReq.Body = reqBody + return nil, nil + } else { + awsClaudeReq, err := formatRequest(requestBody) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody) + } + + if info.IsStream { + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + awsReq.Body, err = common.Marshal(awsClaudeReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody) + } + a.AwsReq = awsReq + return nil, nil + } else { + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + awsReq.Body, err = common.Marshal(awsClaudeReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody) + } + a.AwsReq = awsReq + return nil, nil + } + } } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - err, usage = awsStreamHandler(c, resp, info, a.RequestMode) + if a.IsNova { + err, usage = handleNovaRequest(c, info, a) } else { - err, usage = awsHandler(c, info, a.RequestMode) + if info.IsStream { + err, usage = awsStreamHandler(c, info, a) + } else { + err, usage = awsHandler(c, info, a) + } } return } diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go index 45112d231..b2060b2ad 100644 --- a/relay/channel/aws/constants.go +++ b/relay/channel/aws/constants.go @@ -124,5 +124,5 @@ var ChannelName = "aws" // 判断是否为Nova模型 func isNovaModel(modelId string) bool { - return strings.HasPrefix(modelId, "nova-") + return strings.Contains(modelId, "nova-") } diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go index 0873b671d..1f3952047 100644 --- a/relay/channel/aws/dto.go +++ b/relay/channel/aws/dto.go @@ -1,6 +1,9 @@ package aws import ( + "io" + + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" ) @@ -35,6 +38,16 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest { } } +func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) { + var awsClaudeRequest AwsClaudeRequest + err := common.DecodeJson(requestBody, &awsClaudeRequest) + if err != nil { + return nil, err + } + awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31" + return &awsClaudeRequest, nil +} + // NovaMessage Nova模型使用messages-v1格式 type NovaMessage struct { Role string `json:"role"` diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index fe1b7e7f0..8adbbaaec 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -88,50 +88,9 @@ func awsModelID(requestModel string) string { return requestModel } -func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { - awsCli, err := newAwsClient(c, info) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil - } +func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) { - awsModelId := awsModelID(c.GetString("request_model")) - // 检查是否为Nova模型 - isNova, _ := c.Get("is_nova_model") - if isNova == true { - // Nova模型也支持跨区域 - awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) - canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) - if canCrossRegion { - awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) - } - return handleNovaRequest(c, awsCli, info, awsModelId) - } - - // 原有的Claude处理逻辑 - awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) - canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) - if canCrossRegion { - awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) - } - - awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), - Accept: aws.String("application/json"), - ContentType: aws.String("application/json"), - } - - claudeReq_, ok := c.Get("converted_request") - if !ok { - return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil - } - claudeReq := claudeReq_.(*dto.ClaudeRequest) - awsClaudeReq := copyRequest(claudeReq) - awsReq.Body, err = common.Marshal(awsClaudeReq) - if err != nil { - return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil - } - - awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput)) if err != nil { return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil } @@ -156,39 +115,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (* return nil, claudeInfo.Usage } -func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { - awsCli, err := newAwsClient(c, info) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil - } - - awsModelId := awsModelID(c.GetString("request_model")) - - awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) - canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) - if canCrossRegion { - awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) - } - - awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(awsModelId), - Accept: aws.String("application/json"), - ContentType: aws.String("application/json"), - } - - claudeReq_, ok := c.Get("converted_request") - if !ok { - return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil - } - claudeReq := claudeReq_.(*dto.ClaudeRequest) - - awsClaudeReq := copyRequest(claudeReq) - awsReq.Body, err = common.Marshal(awsClaudeReq) - if err != nil { - return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil - } - - awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) +func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) { + awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput)) if err != nil { return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil } @@ -225,27 +153,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } // Nova模型处理函数 -func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) { - novaReq_, ok := c.Get("converted_request") - if !ok { - return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil - } - novaReq := novaReq_.(*NovaRequest) +func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) { - // 使用InvokeModel API,但使用Nova格式的请求体 - awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), - Accept: aws.String("application/json"), - ContentType: aws.String("application/json"), - } - - reqBody, err := json.Marshal(novaReq) - if err != nil { - return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil - } - awsReq.Body = reqBody - - awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput)) if err != nil { return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil } diff --git a/types/error.go b/types/error.go index 1ca02afc6..9c12034e1 100644 --- a/types/error.go +++ b/types/error.go @@ -62,6 +62,9 @@ const ( ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed" ErrorCodeAccessDenied ErrorCode = "access_denied" + // request error + ErrorCodeBadRequestBody ErrorCode = "bad_request_body" + // response error ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed" ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"