mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:40:27 +00:00
184 lines
5.2 KiB
Go
184 lines
5.2 KiB
Go
package aws
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
"github.com/QuantumNous/new-api/relay/channel"
|
|
"github.com/QuantumNous/new-api/relay/channel/claude"
|
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
|
"github.com/QuantumNous/new-api/service"
|
|
"github.com/QuantumNous/new-api/types"
|
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type ClientMode int
|
|
|
|
const (
|
|
ClientModeApiKey ClientMode = iota + 1
|
|
ClientModeAKSK
|
|
)
|
|
|
|
type Adaptor struct {
|
|
ClientMode ClientMode
|
|
AwsClient *bedrockruntime.Client
|
|
AwsModelId string
|
|
AwsReq any
|
|
IsNova bool
|
|
}
|
|
|
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
|
//TODO implement me
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
|
for i, message := range request.Messages {
|
|
updated := false
|
|
if !message.IsStringContent() {
|
|
content, err := message.ParseContent()
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to parse message content")
|
|
}
|
|
for i2, mediaMessage := range content {
|
|
if mediaMessage.Source != nil {
|
|
if mediaMessage.Source.Type == "url" {
|
|
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
|
}
|
|
mediaMessage.Source.MediaType = fileData.MimeType
|
|
mediaMessage.Source.Data = fileData.Base64Data
|
|
mediaMessage.Source.Url = ""
|
|
mediaMessage.Source.Type = "base64"
|
|
content[i2] = mediaMessage
|
|
updated = true
|
|
}
|
|
}
|
|
}
|
|
if updated {
|
|
message.SetContent(content)
|
|
}
|
|
}
|
|
if updated {
|
|
request.Messages[i] = message
|
|
}
|
|
}
|
|
return request, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
|
//TODO implement me
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
|
//TODO implement me
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
}
|
|
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
|
awsModelId := awsModelID(info.UpstreamModelName)
|
|
a.ClientMode = ClientModeApiKey
|
|
awsSecret := strings.Split(info.ApiKey, "|")
|
|
if len(awsSecret) != 2 {
|
|
return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
|
|
}
|
|
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
|
|
} else {
|
|
a.ClientMode = ClientModeAKSK
|
|
return "", nil
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
|
claude.CommonClaudeHeadersOperation(c, req, info)
|
|
if a.ClientMode == ClientModeApiKey {
|
|
req.Set("Authorization", "Bearer "+info.ApiKey)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
if request == nil {
|
|
return nil, errors.New("request is nil")
|
|
}
|
|
// 检查是否为Nova模型
|
|
if isNovaModel(request.Model) {
|
|
novaReq := convertToNovaRequest(request)
|
|
a.IsNova = true
|
|
return novaReq, nil
|
|
}
|
|
|
|
// 原有的Claude模型处理逻辑
|
|
var claudeReq *dto.ClaudeRequest
|
|
var err error
|
|
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return claudeReq, err
|
|
}
|
|
|
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
|
//TODO implement me
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
|
// TODO implement me
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
if a.ClientMode == ClientModeApiKey {
|
|
return channel.DoApiRequest(a, c, info, requestBody)
|
|
} else {
|
|
return doAwsClientRequest(c, info, a, requestBody)
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
|
if a.ClientMode == ClientModeApiKey {
|
|
claudeAdaptor := claude.Adaptor{}
|
|
usage, err = claudeAdaptor.DoResponse(c, resp, info)
|
|
} else {
|
|
if a.IsNova {
|
|
err, usage = handleNovaRequest(c, info, a)
|
|
} else {
|
|
if info.IsStream {
|
|
err, usage = awsStreamHandler(c, info, a)
|
|
} else {
|
|
err, usage = awsHandler(c, info, a)
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (a *Adaptor) GetModelList() (models []string) {
|
|
for n := range awsModelIDMap {
|
|
models = append(models, n)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (a *Adaptor) GetChannelName() string {
|
|
return ChannelName
|
|
}
|