diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
index 1526a7f75..9d5e5891e 100644
--- a/relay/channel/aws/adaptor.go
+++ b/relay/channel/aws/adaptor.go
@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	// Nova models use the Bedrock "messages-v1" schema instead of the Claude
+	// one; convert here and flag the request so awsHandler can dispatch.
+	if isNovaModel(request.Model) {
+		novaReq := convertToNovaRequest(request)
+		c.Set("request_model", request.Model)
+		c.Set("converted_request", novaReq)
+		c.Set("is_nova_model", true)
+		return novaReq, nil
+	}
 	var claudeReq *dto.ClaudeRequest
 	var err error
 	claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	}
 	c.Set("request_model", claudeReq.Model)
 	c.Set("converted_request", claudeReq)
+	c.Set("is_nova_model", false)
 	return claudeReq, err
 }
diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go
index 3f8800b1e..8ed8f0318 100644
--- a/relay/channel/aws/constants.go
+++ b/relay/channel/aws/constants.go
@@ -1,5 +1,7 @@
 package aws
 
+import "strings"
+
 var awsModelIDMap = map[string]string{
 	"claude-instant-1.2": "anthropic.claude-instant-v1",
 	"claude-2.0":         "anthropic.claude-v2",
@@ -14,6 +16,10 @@ var awsModelIDMap = map[string]string{
 	"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
 	"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
 	"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
+	// Nova models
+	"amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0",
+	"amazon.nova-lite-v1:0":  "us.amazon.nova-lite-v1:0",
+	"amazon.nova-pro-v1:0":   "us.amazon.nova-pro-v1:0",
 }
 
 var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -67,3 +73,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
 }
 
 var ChannelName = "aws"
+
+// isNovaModel reports whether modelId refers to an Amazon Nova model.
+func isNovaModel(modelId string) bool {
+	return strings.HasPrefix(modelId, "amazon.nova-")
+}
diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go
index 0188c30a9..25851ff6f 100644
--- a/relay/channel/aws/dto.go
+++ b/relay/channel/aws/dto.go
@@ -34,3 +34,61 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
 		Thinking: req.Thinking,
 	}
 }
+
+// Nova models use the Bedrock "messages-v1" request schema.
+type NovaMessage struct {
+	Role    string        `json:"role"`
+	Content []NovaContent `json:"content"`
+}
+
+type NovaContent struct {
+	Text string `json:"text"`
+}
+
+type NovaRequest struct {
+	SchemaVersion string `json:"schemaVersion"`
+	// System prompts are top-level in messages-v1; Nova rejects role
+	// "system" inside the messages array — see AWS Nova request schema.
+	System          []NovaContent        `json:"system,omitempty"`
+	Messages        []NovaMessage        `json:"messages"`
+	InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"`
+}
+
+type NovaInferenceConfig struct {
+	MaxTokens   int     `json:"maxTokens,omitempty"`
+	Temperature float64 `json:"temperature,omitempty"`
+	TopP        float64 `json:"topP,omitempty"`
+}
+
+// convertToNovaRequest maps an OpenAI-style chat request onto the Nova
+// messages-v1 schema.
+func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
+	novaReq := &NovaRequest{SchemaVersion: "messages-v1"}
+	for _, msg := range req.Messages {
+		content := []NovaContent{{Text: msg.StringContent()}}
+		if msg.Role == "system" {
+			// Route system prompts to the dedicated top-level field.
+			novaReq.System = append(novaReq.System, content...)
+			continue
+		}
+		novaReq.Messages = append(novaReq.Messages, NovaMessage{Role: msg.Role, Content: content})
+	}
+
+	// Emit inferenceConfig only when at least one knob is set; a pointer is
+	// required because omitempty never omits a non-pointer struct value.
+	if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 {
+		cfg := &NovaInferenceConfig{}
+		if req.MaxTokens != 0 {
+			cfg.MaxTokens = int(req.MaxTokens)
+		}
+		if req.Temperature != nil && *req.Temperature != 0 {
+			cfg.Temperature = *req.Temperature
+		}
+		if req.TopP != 0 {
+			cfg.TopP = req.TopP
+		}
+		novaReq.InferenceConfig = cfg
+	}
+
+	return novaReq
+}
diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go
index 26e234fa3..3df6b33dd 100644
--- a/relay/channel/aws/relay-aws.go
+++ b/relay/channel/aws/relay-aws.go
@@ -1,6 +1,7 @@
 package aws
 
 import (
+	"encoding/json"
 	"fmt"
 	"net/http"
 	"one-api/common"
@@ -93,7 +94,11 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
 	}
 	awsModelId := awsModelID(c.GetString("request_model"))
+	// Nova requests bypass the Claude/cross-region logic entirely.
+	if c.GetBool("is_nova_model") {
+		return handleNovaRequest(c, awsCli, info, awsModelId)
+	}
 	awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
 	canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
 	if canCrossRegion {
@@ -209,3 +214,79 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
 	return nil, claudeInfo.Usage
 }
+
+// handleNovaRequest invokes a Nova model through the InvokeModel API and
+// re-shapes the Nova response into an OpenAI chat completion.
+func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
+	rawReq, ok := c.Get("converted_request")
+	if !ok {
+		return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
+	}
+	// Comma-ok assertion so a mismatched context value cannot panic.
+	novaReq, ok := rawReq.(*NovaRequest)
+	if !ok {
+		return types.NewError(errors.New("converted request is not a nova request"), types.ErrorCodeInvalidRequest), nil
+	}
+
+	reqBody, err := json.Marshal(novaReq)
+	if err != nil {
+		return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
+	}
+	awsReq := &bedrockruntime.InvokeModelInput{
+		ModelId:     aws.String(awsModelId),
+		Accept:      aws.String("application/json"),
+		ContentType: aws.String("application/json"),
+		Body:        reqBody,
+	}
+
+	awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+	if err != nil {
+		return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
+	}
+
+	// Minimal view of the Nova response body; unknown fields are ignored.
+	var novaResp struct {
+		Output struct {
+			Message struct {
+				Content []struct {
+					Text string `json:"text"`
+				} `json:"content"`
+			} `json:"message"`
+		} `json:"output"`
+		Usage struct {
+			InputTokens  int `json:"inputTokens"`
+			OutputTokens int `json:"outputTokens"`
+			TotalTokens  int `json:"totalTokens"`
+		} `json:"usage"`
+	}
+	if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
+		return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
+	}
+	// Guard the content slice so an empty reply cannot cause an index panic.
+	if len(novaResp.Output.Message.Content) == 0 {
+		return types.NewError(errors.New("nova response contains no content"), types.ErrorCodeBadResponseBody), nil
+	}
+
+	response := dto.OpenAITextResponse{
+		Id:      helper.GetResponseID(c),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Model:   info.UpstreamModelName,
+		Choices: []dto.OpenAITextResponseChoice{{
+			Index: 0,
+			Message: dto.Message{
+				Role:    "assistant",
+				Content: novaResp.Output.Message.Content[0].Text,
+			},
+			FinishReason: "stop",
+		}},
+		Usage: dto.Usage{
+			PromptTokens:     novaResp.Usage.InputTokens,
+			CompletionTokens: novaResp.Usage.OutputTokens,
+			TotalTokens:      novaResp.Usage.TotalTokens,
+		},
+	}
+
+	c.JSON(http.StatusOK, response)
+	return nil, &response.Usage
+}