mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 05:02:17 +00:00
feat: support aws bedrock api-keys-use
This commit is contained in:
@@ -16,6 +16,13 @@ const (
|
||||
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||
)
|
||||
|
||||
type AwsKeyType string
|
||||
|
||||
const (
|
||||
AwsKeyTypeAKSK AwsKeyType = "ak_sk" // 默认
|
||||
AwsKeyTypeApiKey AwsKeyType = "api_key"
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
@@ -23,6 +30,7 @@ type ChannelOtherSettings struct {
|
||||
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
|
||||
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
|
||||
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ClientMode int
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
ClientModeApiKey ClientMode = iota + 1
|
||||
ClientModeAKSK
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
ClientMode ClientMode
|
||||
AwsClient *bedrockruntime.Client
|
||||
AwsModelId string
|
||||
AwsReq any
|
||||
@@ -51,11 +55,25 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
a.ClientMode = ClientModeApiKey
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 2 {
|
||||
return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
|
||||
}
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
|
||||
} else {
|
||||
a.ClientMode = ClientModeAKSK
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
claude.CommonClaudeHeadersOperation(c, req, info)
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -95,82 +113,26 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
if isNovaModel(awsModelId) {
|
||||
var novaReq *NovaRequest
|
||||
err = common.DecodeJson(requestBody, &novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := common.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
return nil, nil
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
awsClaudeReq, err := formatRequest(requestBody)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
} else {
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
}
|
||||
return doAwsClientRequest(c, info, a, requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if a.IsNova {
|
||||
err, usage = handleNovaRequest(c, info, a)
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
claudeAdaptor := claude.Adaptor{}
|
||||
usage, err = claudeAdaptor.DoResponse(c, resp, info)
|
||||
} else {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a)
|
||||
if a.IsNova {
|
||||
err, usage = handleNovaRequest(c, info, a)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a)
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ package aws
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -49,12 +50,72 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("%s", err.Error()),
|
||||
},
|
||||
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
if isNovaModel(awsModelId) {
|
||||
var novaReq *NovaRequest
|
||||
err = common.DecodeJson(requestBody, &novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := common.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
return nil, nil
|
||||
} else {
|
||||
awsClaudeReq, err := formatRequest(requestBody)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
} else {
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,7 +169,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
@@ -135,7 +196,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
||||
switch v := event.(type) {
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
@@ -148,7 +209,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
||||
}
|
||||
}
|
||||
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
|
||||
@@ -153,6 +153,8 @@ const EditChannelModal = (props) => {
|
||||
settings: '',
|
||||
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
|
||||
vertex_key_type: 'json',
|
||||
// 仅 AWS: 密钥格式和区域(存入 settings.aws_key_type 和 settings.aws_region)
|
||||
aws_key_type: 'ak_sk',
|
||||
// 企业账户设置
|
||||
is_enterprise_account: false,
|
||||
// 字段透传控制默认值
|
||||
@@ -515,6 +517,8 @@ const EditChannelModal = (props) => {
|
||||
parsedSettings.azure_responses_version || '';
|
||||
// 读取 Vertex 密钥格式
|
||||
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
|
||||
// 读取 AWS 密钥格式和区域
|
||||
data.aws_key_type = parsedSettings.aws_key_type || 'ak_sk';
|
||||
// 读取企业账户设置
|
||||
data.is_enterprise_account =
|
||||
parsedSettings.openrouter_enterprise === true;
|
||||
@@ -528,6 +532,7 @@ const EditChannelModal = (props) => {
|
||||
data.azure_responses_version = '';
|
||||
data.region = '';
|
||||
data.vertex_key_type = 'json';
|
||||
data.aws_key_type = 'ak_sk';
|
||||
data.is_enterprise_account = false;
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
@@ -536,6 +541,7 @@ const EditChannelModal = (props) => {
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
data.vertex_key_type = 'json';
|
||||
data.aws_key_type = 'ak_sk';
|
||||
data.is_enterprise_account = false;
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
@@ -997,6 +1003,11 @@ const EditChannelModal = (props) => {
|
||||
localInputs.is_enterprise_account === true;
|
||||
}
|
||||
|
||||
// type === 33 (AWS): 保存 aws_key_type 到 settings
|
||||
if (localInputs.type === 33) {
|
||||
settings.aws_key_type = localInputs.aws_key_type || 'ak_sk';
|
||||
}
|
||||
|
||||
// type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
|
||||
if (localInputs.type === 1 || localInputs.type === 14) {
|
||||
settings.allow_service_tier = localInputs.allow_service_tier === true;
|
||||
@@ -1020,6 +1031,8 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.is_enterprise_account;
|
||||
// 顶层的 vertex_key_type 不应发送给后端
|
||||
delete localInputs.vertex_key_type;
|
||||
// 顶层的 aws_key_type 不应发送给后端
|
||||
delete localInputs.aws_key_type;
|
||||
// 清理字段透传控制的临时字段
|
||||
delete localInputs.allow_service_tier;
|
||||
delete localInputs.disable_store;
|
||||
@@ -1468,6 +1481,31 @@ const EditChannelModal = (props) => {
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
|
||||
{inputs.type === 33 && (
|
||||
<>
|
||||
<Form.Select
|
||||
field='aws_key_type'
|
||||
label={t('密钥格式')}
|
||||
placeholder={t('请选择密钥格式')}
|
||||
optionList={[
|
||||
{
|
||||
label: 'Access Key ID / Secret Access Key',
|
||||
value: 'ak_sk',
|
||||
},
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
]}
|
||||
style={{ width: '100%' }}
|
||||
value={inputs.aws_key_type || 'ak_sk'}
|
||||
onChange={(value) => {
|
||||
handleChannelOtherSettingsChange('aws_key_type', value);
|
||||
}}
|
||||
extraText={t(
|
||||
'AK/SK 模式:使用 Access Key ID 和 Secret Access Key;API Key 模式:使用 API Key',
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{inputs.type === 41 && (
|
||||
<Form.Select
|
||||
field='vertex_key_type'
|
||||
@@ -1536,7 +1574,15 @@ const EditChannelModal = (props) => {
|
||||
<Form.TextArea
|
||||
field='key'
|
||||
label={t('密钥')}
|
||||
placeholder={t('请输入密钥,一行一个')}
|
||||
placeholder={
|
||||
inputs.type === 33
|
||||
? inputs.aws_key_type === 'api_key'
|
||||
? t('请输入 API Key,一行一个,格式:API Key|Region')
|
||||
: t(
|
||||
'请输入密钥,一行一个,格式:Access Key ID|Secret Access Key|Region',
|
||||
)
|
||||
: t('请输入密钥,一行一个')
|
||||
}
|
||||
rules={
|
||||
isEdit
|
||||
? []
|
||||
@@ -1730,7 +1776,13 @@ const EditChannelModal = (props) => {
|
||||
? t('密钥(编辑模式下,保存的密钥不会显示)')
|
||||
: t('密钥')
|
||||
}
|
||||
placeholder={t(type2secretPrompt(inputs.type))}
|
||||
placeholder={
|
||||
inputs.type === 33
|
||||
? inputs.aws_key_type === 'api_key'
|
||||
? t('请输入 API Key,格式:API Key|Region')
|
||||
: t('按照如下格式输入:Access Key ID|Secret Access Key|Region')
|
||||
: t(type2secretPrompt(inputs.type))
|
||||
}
|
||||
rules={
|
||||
isEdit
|
||||
? []
|
||||
|
||||
Reference in New Issue
Block a user