diff --git a/dto/channel_settings.go b/dto/channel_settings.go index d57184b38..e88f2235e 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -16,6 +16,13 @@ const ( VertexKeyTypeAPIKey VertexKeyType = "api_key" ) +type AwsKeyType string + +const ( + AwsKeyTypeAKSK AwsKeyType = "ak_sk" // 默认 + AwsKeyTypeApiKey AwsKeyType = "api_key" +) + type ChannelOtherSettings struct { AzureResponsesVersion string `json:"azure_responses_version,omitempty"` VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" @@ -23,6 +30,7 @@ type ChannelOtherSettings struct { AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) + AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` } func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool { diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 5d3f9ac71..70736ec20 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -1,27 +1,31 @@ package aws import ( + "fmt" "io" "net/http" + "strings" - "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/relay/channel" "github.com/QuantumNous/new-api/relay/channel/claude" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/pkg/errors" "github.com/gin-gonic/gin" ) +type ClientMode int + const ( - RequestModeCompletion = 1 - RequestModeMessage = 2 + ClientModeApiKey ClientMode = iota + 1 + ClientModeAKSK ) type Adaptor struct { + ClientMode ClientMode AwsClient *bedrockruntime.Client AwsModelId string AwsReq any @@ -51,11 +55,25 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return "", nil + if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey { + awsModelId := awsModelID(info.UpstreamModelName) + a.ClientMode = ClientModeApiKey + awsSecret := strings.Split(info.ApiKey, "|") + if len(awsSecret) != 2 { + return "", errors.New("invalid aws api key, should be in format of |") + } + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil + } else { + a.ClientMode = ClientModeAKSK + return "", nil + } } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { claude.CommonClaudeHeadersOperation(c, req, info) + if a.ClientMode == ClientModeApiKey { + req.Set("Authorization", "Bearer "+info.ApiKey) + } return nil } @@ -95,82 +113,26 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - awsCli, err := newAwsClient(c, info) - if err != nil { - return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError) - } - a.AwsClient = awsCli - - awsModelId := awsModelID(info.UpstreamModelName) - - awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) - canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) - if canCrossRegion { - awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) - } - - if isNovaModel(awsModelId) { - var novaReq *NovaRequest - err = common.DecodeJson(requestBody, &novaReq) - if err != nil { - return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody) - } - - // 使用InvokeModel API,但使用Nova格式的请求体 - awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), - Accept: aws.String("application/json"), - ContentType: aws.String("application/json"), - } - - reqBody, err := common.Marshal(novaReq) - if err != nil { - return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody) - } - awsReq.Body = reqBody - return nil, nil + if a.ClientMode == ClientModeApiKey { + return channel.DoApiRequest(a, c, info, requestBody) } else { - awsClaudeReq, err := formatRequest(requestBody) - if err != nil { - return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody) - } - - if info.IsStream { - awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(awsModelId), - Accept: aws.String("application/json"), - ContentType: aws.String("application/json"), - } - awsReq.Body, err = common.Marshal(awsClaudeReq) - if err != nil { - return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody) - } - a.AwsReq = awsReq - return nil, nil - } else { - awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), - Accept: aws.String("application/json"), - ContentType: aws.String("application/json"), - } - awsReq.Body, err = common.Marshal(awsClaudeReq) - if err != nil { - return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody) - } - a.AwsReq = awsReq - return nil, nil - } + return doAwsClientRequest(c, info, a, requestBody) } } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if a.IsNova { - err, usage = handleNovaRequest(c, info, a) + if a.ClientMode == ClientModeApiKey { + claudeAdaptor := claude.Adaptor{} + usage, err = claudeAdaptor.DoResponse(c, resp, info) } else { - if info.IsStream { - err, usage = awsStreamHandler(c, info, a) + if a.IsNova { + err, usage = handleNovaRequest(c, info, a) } else { - err, usage = awsHandler(c, info, a) + if info.IsStream { + err, usage = awsStreamHandler(c, info, a) + } else { + err, usage = awsHandler(c, info, a) + } } } return diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go index 8adbbaaec..f55690949 100644 --- a/relay/channel/aws/relay-aws.go +++ b/relay/channel/aws/relay-aws.go @@ -3,6 +3,7 @@ package aws import ( "encoding/json" "fmt" + "io" "net/http" "strings" @@ -49,12 +50,72 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime. return client, nil } -func wrapErr(err error) *dto.OpenAIErrorWithStatusCode { - return &dto.OpenAIErrorWithStatusCode{ - StatusCode: http.StatusInternalServerError, - Error: dto.OpenAIError{ - Message: fmt.Sprintf("%s", err.Error()), - }, +func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) { + awsCli, err := newAwsClient(c, info) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError) + } + a.AwsClient = awsCli + + awsModelId := awsModelID(info.UpstreamModelName) + + awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region) + canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix) + if canCrossRegion { + awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix) + } + + if isNovaModel(awsModelId) { + var novaReq *NovaRequest + err = common.DecodeJson(requestBody, &novaReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody) + } + + // 使用InvokeModel API,但使用Nova格式的请求体 + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + reqBody, err := common.Marshal(novaReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody) + } + awsReq.Body = reqBody + return nil, nil + } else { + awsClaudeReq, err := formatRequest(requestBody) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody) + } + + if info.IsStream { + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + awsReq.Body, err = common.Marshal(awsClaudeReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody) + } + a.AwsReq = awsReq + return nil, nil + } else { + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + awsReq.Body, err = common.Marshal(awsClaudeReq) + if err != nil { + return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody) + } + a.AwsReq = awsReq + return nil, nil + } } } @@ -108,7 +169,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types c.Writer.Header().Set("Content-Type", *awsResp.ContentType) } - handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage) + handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage) if handlerErr != nil { return handlerErr, nil } @@ -135,7 +196,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) ( switch v := event.(type) { case *bedrockruntimeTypes.ResponseStreamMemberChunk: info.SetFirstResponseTime() - respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage) + respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage) if respErr != nil { return respErr, nil } @@ -148,7 +209,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) ( } } - claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage) + claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage) return nil, claudeInfo.Usage } diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 5fa990c0c..949bbef70 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -153,6 +153,8 @@ const EditChannelModal = (props) => { settings: '', // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type) vertex_key_type: 'json', + // 仅 AWS: 密钥格式和区域(存入 settings.aws_key_type 和 settings.aws_region) + aws_key_type: 'ak_sk', // 企业账户设置 is_enterprise_account: false, // 字段透传控制默认值 @@ -515,6 +517,8 @@ const EditChannelModal = (props) => { parsedSettings.azure_responses_version || ''; // 读取 Vertex 密钥格式 data.vertex_key_type = parsedSettings.vertex_key_type || 'json'; + // 读取 AWS 密钥格式和区域 + data.aws_key_type = parsedSettings.aws_key_type || 'ak_sk'; // 读取企业账户设置 data.is_enterprise_account = parsedSettings.openrouter_enterprise === true; @@ -528,6 +532,7 @@ const EditChannelModal = (props) => { data.azure_responses_version = ''; data.region = ''; data.vertex_key_type = 'json'; + data.aws_key_type = 'ak_sk'; data.is_enterprise_account = false; data.allow_service_tier = false; data.disable_store = false; @@ -536,6 +541,7 @@ const EditChannelModal = (props) => { } else { // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示 data.vertex_key_type = 'json'; + data.aws_key_type = 'ak_sk'; data.is_enterprise_account = false; data.allow_service_tier = false; data.disable_store = false; @@ -997,6 +1003,11 @@ const EditChannelModal = (props) => { localInputs.is_enterprise_account === true; } + // type === 33 (AWS): 保存 aws_key_type 到 settings + if (localInputs.type === 33) { + settings.aws_key_type = localInputs.aws_key_type || 'ak_sk'; + } + // type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值) if (localInputs.type === 1 || localInputs.type === 14) { settings.allow_service_tier = localInputs.allow_service_tier === true; @@ -1020,6 +1031,8 @@ const EditChannelModal = (props) => { delete localInputs.is_enterprise_account; // 顶层的 vertex_key_type 不应发送给后端 delete localInputs.vertex_key_type; + // 顶层的 aws_key_type 不应发送给后端 + delete localInputs.aws_key_type; // 清理字段透传控制的临时字段 delete localInputs.allow_service_tier; delete localInputs.disable_store; @@ -1468,6 +1481,31 @@ const EditChannelModal = (props) => { autoComplete='new-password' /> + {inputs.type === 33 && ( + <> + { + handleChannelOtherSettingsChange('aws_key_type', value); + }} + extraText={t( + 'AK/SK 模式:使用 Access Key ID 和 Secret Access Key;API Key 模式:使用 API Key', + )} + /> + + )} + {inputs.type === 41 && ( { { ? t('密钥(编辑模式下,保存的密钥不会显示)') : t('密钥') } - placeholder={t(type2secretPrompt(inputs.type))} + placeholder={ + inputs.type === 33 + ? inputs.aws_key_type === 'api_key' + ? t('请输入 API Key,格式:API Key|Region') + : t('按照如下格式输入:Access Key ID|Secret Access Key|Region') + : t(type2secretPrompt(inputs.type)) + } rules={ isEdit ? []