mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
refactor: aws
This commit is contained in:
@@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Unmarshal(data []byte, v any) error {
|
func Unmarshal(data []byte, v any) error {
|
||||||
@@ -13,7 +14,7 @@ func UnmarshalJsonStr(data string, v any) error {
|
|||||||
return json.Unmarshal(StringToByteSlice(data), v)
|
return json.Unmarshal(StringToByteSlice(data), v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
func DecodeJson(reader io.Reader, v any) error {
|
||||||
return json.NewDecoder(reader).Decode(v)
|
return json.NewDecoder(reader).Decode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ const (
|
|||||||
LogTypeConsume
|
LogTypeConsume
|
||||||
LogTypeManage
|
LogTypeManage
|
||||||
LogTypeSystem
|
LogTypeSystem
|
||||||
|
LogTypeRefund
|
||||||
LogTypeError
|
LogTypeError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,17 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/types"
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -19,7 +22,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
RequestMode int
|
AwsClient *bedrockruntime.Client
|
||||||
|
AwsModelId string
|
||||||
|
AwsReq any
|
||||||
|
IsNova bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
@@ -28,8 +34,6 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||||
c.Set("request_model", request.Model)
|
|
||||||
c.Set("converted_request", request)
|
|
||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +48,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
a.RequestMode = RequestModeMessage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
@@ -63,9 +66,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
// 检查是否为Nova模型
|
// 检查是否为Nova模型
|
||||||
if isNovaModel(request.Model) {
|
if isNovaModel(request.Model) {
|
||||||
novaReq := convertToNovaRequest(request)
|
novaReq := convertToNovaRequest(request)
|
||||||
c.Set("request_model", request.Model)
|
a.IsNova = true
|
||||||
c.Set("converted_request", novaReq)
|
|
||||||
c.Set("is_nova_model", true)
|
|
||||||
return novaReq, nil
|
return novaReq, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,9 +77,6 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.Set("request_model", claudeReq.Model)
|
|
||||||
c.Set("converted_request", claudeReq)
|
|
||||||
c.Set("is_nova_model", false)
|
|
||||||
return claudeReq, err
|
return claudeReq, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,14 +95,83 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||||
return nil, nil
|
awsCli, err := newAwsClient(c, info)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
||||||
|
}
|
||||||
|
a.AwsClient = awsCli
|
||||||
|
|
||||||
|
awsModelId := awsModelID(info.UpstreamModelName)
|
||||||
|
|
||||||
|
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||||
|
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
|
if canCrossRegion {
|
||||||
|
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNovaModel(awsModelId) {
|
||||||
|
var novaReq *NovaRequest
|
||||||
|
err = common.DecodeJson(requestBody, &novaReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||||
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, err := common.Marshal(novaReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||||
|
}
|
||||||
|
awsReq.Body = reqBody
|
||||||
|
return nil, nil
|
||||||
|
} else {
|
||||||
|
awsClaudeReq, err := formatRequest(requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsStream {
|
||||||
|
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||||
|
}
|
||||||
|
a.AwsReq = awsReq
|
||||||
|
return nil, nil
|
||||||
|
} else {
|
||||||
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||||
|
}
|
||||||
|
a.AwsReq = awsReq
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if a.IsNova {
|
||||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = handleNovaRequest(c, info, a)
|
||||||
} else {
|
} else {
|
||||||
err, usage = awsHandler(c, info, a.RequestMode)
|
if info.IsStream {
|
||||||
|
err, usage = awsStreamHandler(c, info, a)
|
||||||
|
} else {
|
||||||
|
err, usage = awsHandler(c, info, a)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -124,5 +124,5 @@ var ChannelName = "aws"
|
|||||||
|
|
||||||
// 判断是否为Nova模型
|
// 判断是否为Nova模型
|
||||||
func isNovaModel(modelId string) bool {
|
func isNovaModel(modelId string) bool {
|
||||||
return strings.HasPrefix(modelId, "nova-")
|
return strings.Contains(modelId, "nova-")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,6 +38,16 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
|
||||||
|
var awsClaudeRequest AwsClaudeRequest
|
||||||
|
err := common.DecodeJson(requestBody, &awsClaudeRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
|
||||||
|
return &awsClaudeRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
// NovaMessage Nova模型使用messages-v1格式
|
// NovaMessage Nova模型使用messages-v1格式
|
||||||
type NovaMessage struct {
|
type NovaMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
|
|||||||
@@ -88,50 +88,9 @@ func awsModelID(requestModel string) string {
|
|||||||
return requestModel
|
return requestModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||||
awsCli, err := newAwsClient(c, info)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
awsModelId := awsModelID(c.GetString("request_model"))
|
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||||
// 检查是否为Nova模型
|
|
||||||
isNova, _ := c.Get("is_nova_model")
|
|
||||||
if isNova == true {
|
|
||||||
// Nova模型也支持跨区域
|
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
||||||
if canCrossRegion {
|
|
||||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
||||||
}
|
|
||||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 原有的Claude处理逻辑
|
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
||||||
if canCrossRegion {
|
|
||||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
awsReq := &bedrockruntime.InvokeModelInput{
|
|
||||||
ModelId: aws.String(awsModelId),
|
|
||||||
Accept: aws.String("application/json"),
|
|
||||||
ContentType: aws.String("application/json"),
|
|
||||||
}
|
|
||||||
|
|
||||||
claudeReq_, ok := c.Get("converted_request")
|
|
||||||
if !ok {
|
|
||||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
|
||||||
}
|
|
||||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
|
||||||
awsClaudeReq := copyRequest(claudeReq)
|
|
||||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
@@ -156,39 +115,8 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|||||||
return nil, claudeInfo.Usage
|
return nil, claudeInfo.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||||
awsCli, err := newAwsClient(c, info)
|
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||||
if err != nil {
|
|
||||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
awsModelId := awsModelID(c.GetString("request_model"))
|
|
||||||
|
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
|
||||||
if canCrossRegion {
|
|
||||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
|
||||||
}
|
|
||||||
|
|
||||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
|
||||||
ModelId: aws.String(awsModelId),
|
|
||||||
Accept: aws.String("application/json"),
|
|
||||||
ContentType: aws.String("application/json"),
|
|
||||||
}
|
|
||||||
|
|
||||||
claudeReq_, ok := c.Get("converted_request")
|
|
||||||
if !ok {
|
|
||||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
|
||||||
}
|
|
||||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
|
||||||
|
|
||||||
awsClaudeReq := copyRequest(claudeReq)
|
|
||||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
@@ -225,27 +153,9 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Nova模型处理函数
|
// Nova模型处理函数
|
||||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||||
novaReq_, ok := c.Get("converted_request")
|
|
||||||
if !ok {
|
|
||||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
|
||||||
}
|
|
||||||
novaReq := novaReq_.(*NovaRequest)
|
|
||||||
|
|
||||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||||
awsReq := &bedrockruntime.InvokeModelInput{
|
|
||||||
ModelId: aws.String(awsModelId),
|
|
||||||
Accept: aws.String("application/json"),
|
|
||||||
ContentType: aws.String("application/json"),
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBody, err := json.Marshal(novaReq)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
|
||||||
}
|
|
||||||
awsReq.Body = reqBody
|
|
||||||
|
|
||||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,6 +62,9 @@ const (
|
|||||||
ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
|
ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
|
||||||
ErrorCodeAccessDenied ErrorCode = "access_denied"
|
ErrorCodeAccessDenied ErrorCode = "access_denied"
|
||||||
|
|
||||||
|
// request error
|
||||||
|
ErrorCodeBadRequestBody ErrorCode = "bad_request_body"
|
||||||
|
|
||||||
// response error
|
// response error
|
||||||
ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
|
ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
|
||||||
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
||||||
|
|||||||
Reference in New Issue
Block a user