Files
new-api/relay/channel/aws/relay-aws.go
zdwy5 e1bee48152 fix: 支持aws 通过全局参数透传或者渠道参数透传来 调用 (#2423)
* fix: 支持aws 通过全局参数透传或者渠道参数透传来 调用

* fix(aws): replace json.Unmarshal with common.Unmarshal for request body processing

---------

Co-authored-by: r0 <liangchunlei@01.ai>
Co-authored-by: CaIon <i@caion.me>
2025-12-12 17:09:27 +08:00

321 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package aws
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel/claude"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
"github.com/aws/smithy-go/auth/bearer"
)
// getAwsErrorStatusCode extracts HTTP status code from AWS SDK error
func getAwsErrorStatusCode(err error) int {
// Check for HTTP response error which contains status code
var httpErr interface{ HTTPStatusCode() int }
if errors.As(err, &httpErr) {
return httpErr.HTTPStatusCode()
}
// Default to 500 if we can't determine the status code
return http.StatusInternalServerError
}
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
var (
httpClient *http.Client
err error
)
if info.ChannelSetting.Proxy != "" {
httpClient, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
httpClient = service.GetHttpClient()
}
awsSecret := strings.Split(info.ApiKey, "|")
var client *bedrockruntime.Client
switch len(awsSecret) {
case 2:
apiKey := awsSecret[0]
region := awsSecret[1]
client = bedrockruntime.New(bedrockruntime.Options{
Region: region,
BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
HTTPClient: httpClient,
})
case 3:
ak := awsSecret[0]
sk := awsSecret[1]
region := awsSecret[2]
client = bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
HTTPClient: httpClient,
})
default:
return nil, errors.New("invalid aws secret key")
}
return client, nil
}
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
}
a.AwsClient = awsCli
// 获取对应的AWS模型ID
awsModelId := getAwsModelID(info.UpstreamModelName)
awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
// init empty request.header
requestHeader := http.Header{}
a.SetupRequestHeader(c, &requestHeader, info)
if isNovaModel(awsModelId) {
var novaReq *NovaRequest
err = common.DecodeJson(requestBody, &novaReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
}
// 使用InvokeModel API但使用Nova格式的请求体
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
reqBody, err := common.Marshal(novaReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
}
awsReq.Body = reqBody
return nil, nil
} else {
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
}
if info.IsStream {
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
a.AwsReq = awsReq
return nil, nil
} else {
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
a.AwsReq = awsReq
return nil, nil
}
}
}
// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return nil, errors.Wrap(err, "get request body for pass-through fail")
}
var data map[string]interface{}
if err := common.Unmarshal(body, &data); err != nil {
return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
}
delete(data, "model")
delete(data, "stream")
return common.Marshal(data)
}
return common.Marshal(awsClaudeReq)
}
func getAwsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
regionPrefix = parts[0]
}
return regionPrefix
}
func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
return exists && regionSet[awsRegionPrefix]
}
func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
if !find {
return awsModelId
}
return modelPrefix + "." + awsModelId
}
func getAwsModelID(requestModel string) string {
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
return awsModelIDName
}
return requestModel
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
if err != nil {
statusCode := getAwsErrorStatusCode(err)
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
}
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
// 复制上游 Content-Type 到客户端响应头
if awsResp.ContentType != nil && *awsResp.ContentType != "" {
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
}
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
if handlerErr != nil {
return handlerErr, nil
}
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
if err != nil {
statusCode := getAwsErrorStatusCode(err)
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
}
stream := awsResp.GetStream()
defer stream.Close()
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
for event := range stream.Events() {
switch v := event.(type) {
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
if respErr != nil {
return respErr, nil
}
case *bedrockruntimeTypes.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
default:
fmt.Println("union is nil or unknown type")
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
}
}
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
return nil, claudeInfo.Usage
}
// Nova模型处理函数
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
if err != nil {
statusCode := getAwsErrorStatusCode(err)
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
}
// 解析Nova响应
var novaResp struct {
Output struct {
Message struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
} `json:"message"`
} `json:"output"`
Usage struct {
InputTokens int `json:"inputTokens"`
OutputTokens int `json:"outputTokens"`
TotalTokens int `json:"totalTokens"`
} `json:"usage"`
}
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
}
// 构造OpenAI格式响应
response := dto.OpenAITextResponse{
Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
Choices: []dto.OpenAITextResponseChoice{{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: novaResp.Output.Message.Content[0].Text,
},
FinishReason: "stop",
}},
Usage: dto.Usage{
PromptTokens: novaResp.Usage.InputTokens,
CompletionTokens: novaResp.Usage.OutputTokens,
TotalTokens: novaResp.Usage.TotalTokens,
},
}
c.JSON(http.StatusOK, response)
return nil, &response.Usage
}