mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 18:43:40 +00:00
Compare commits
23 Commits
v0.9.7-pat
...
v0.9.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd4160793e | ||
|
|
82e21972ec | ||
|
|
dce00141ce | ||
|
|
b2a057723a | ||
|
|
f023efdbfc | ||
|
|
8b65623726 | ||
|
|
aa35d8db69 | ||
|
|
64ed7dce4d | ||
|
|
67c321c4fb | ||
|
|
b3f50e9dd0 | ||
|
|
ea870a7846 | ||
|
|
fa21599fc8 | ||
|
|
e6c42bfbda | ||
|
|
7d480d5ff3 | ||
|
|
86c63ea4a7 | ||
|
|
2624c48113 | ||
|
|
384cba92cf | ||
|
|
7222265fee | ||
|
|
fdbc31eb9a | ||
|
|
3172c956f7 | ||
|
|
8b9188c584 | ||
|
|
5fc9152499 | ||
|
|
18b945b9c5 |
16
README.md
16
README.md
@@ -165,12 +165,18 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
|
||||
#### 使用Docker Compose部署(推荐)
|
||||
```shell
|
||||
# 下载项目
|
||||
git clone https://github.com/Calcium-Ion/new-api.git
|
||||
# 下载项目源码
|
||||
git clone https://github.com/QuantumNous/new-api.git
|
||||
|
||||
# 进入项目目录
|
||||
cd new-api
|
||||
# 按需编辑docker-compose.yml
|
||||
# 启动
|
||||
docker-compose up -d
|
||||
|
||||
# 根据需要编辑 docker-compose.yml 文件
|
||||
# 使用nano编辑器
|
||||
nano docker-compose.yml
|
||||
# 或使用vim编辑器
|
||||
# vim docker-compose.yml
|
||||
|
||||
```
|
||||
|
||||
#### 直接使用Docker镜像
|
||||
|
||||
@@ -86,5 +86,8 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
} else {
|
||||
err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
|
||||
}
|
||||
if err != nil {
|
||||
SysError(fmt.Sprintf("failed to send email to %s: %v", receiver, err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ func GetEndpointTypesByChannelType(channelType int, modelName string) []constant
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeSora:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIVideo}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
@@ -13,7 +14,7 @@ func UnmarshalJsonStr(data string, v any) error {
|
||||
return json.Unmarshal(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
func DecodeJson(reader io.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ const (
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
EndpointTypeEmbeddings EndpointType = "embeddings"
|
||||
EndpointTypeOpenAIVideo EndpointType = "openai-video"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
|
||||
@@ -229,7 +229,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
@@ -299,6 +299,9 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
}
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
|
||||
@@ -88,10 +88,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
@@ -105,9 +108,12 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
|
||||
@@ -30,11 +30,14 @@ services:
|
||||
# - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service, uncomment if using MySQL
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 batch update enabled
|
||||
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! multi-node deployment, set this to a random string!!!!!!!
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录 (Whether to enable error log recording)
|
||||
- BATCH_UPDATE_ENABLED=true # 是否启用批量更新 (Whether to enable batch update)
|
||||
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 (Streaming timeout in seconds, default is 120s. Increase if experiencing empty completions)
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!! (multi-node deployment, set this to a random string!!!!!!!)
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - GOOGLE_ANALYTICS_ID=G-XXXXXXXXXX # Google Analytics 的测量 ID (Google Analytics Measurement ID)
|
||||
# - UMAMI_WEBSITE_ID=xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx # Umami 网站 ID (Umami Website ID)
|
||||
# - UMAMI_SCRIPT_URL=https://analytics.umami.is/script.js # Umami 脚本 URL,默认为官方地址 (Umami Script URL, defaults to official URL)
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
|
||||
@@ -16,6 +16,13 @@ const (
|
||||
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||
)
|
||||
|
||||
type AwsKeyType string
|
||||
|
||||
const (
|
||||
AwsKeyTypeAKSK AwsKeyType = "ak_sk" // 默认
|
||||
AwsKeyTypeApiKey AwsKeyType = "api_key"
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
@@ -23,6 +30,7 @@ type ChannelOtherSettings struct {
|
||||
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
|
||||
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
|
||||
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
|
||||
|
||||
@@ -24,7 +24,7 @@ type ClaudeMediaMessage struct {
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
PartialJson *string `json:"partial_json,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
CacheControl json.RawMessage `json:"cache_control,omitempty"`
|
||||
@@ -148,6 +148,10 @@ func (c *ClaudeMessage) SetStringContent(content string) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) SetContent(content any) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ type OpenAIVideo struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
|
||||
Error *OpenAIVideoError `json:"error,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Metadata map[string]any `json:"meta_data,omitempty"`
|
||||
}
|
||||
|
||||
func (m *OpenAIVideo) SetProgressStr(progress string) {
|
||||
|
||||
53
main.go
53
main.go
@@ -150,6 +150,26 @@ func main() {
|
||||
})
|
||||
server.Use(sessions.Sessions("session", store))
|
||||
|
||||
InjectUmamiAnalytics()
|
||||
InjectGoogleAnalytics()
|
||||
|
||||
// 设置路由
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
}
|
||||
|
||||
// Log startup success message
|
||||
common.LogStartupSuccess(startTime, port)
|
||||
|
||||
err = server.Run(":" + port)
|
||||
if err != nil {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func InjectUmamiAnalytics() {
|
||||
analyticsInjectBuilder := &strings.Builder{}
|
||||
if os.Getenv("UMAMI_WEBSITE_ID") != "" {
|
||||
umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
|
||||
@@ -164,21 +184,28 @@ func main() {
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<analytics></analytics>\n"), []byte(analyticsInject))
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
}
|
||||
|
||||
// Log startup success message
|
||||
common.LogStartupSuccess(startTime, port)
|
||||
|
||||
err = server.Run(":" + port)
|
||||
if err != nil {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
func InjectGoogleAnalytics() {
|
||||
analyticsInjectBuilder := &strings.Builder{}
|
||||
if os.Getenv("GOOGLE_ANALYTICS_ID") != "" {
|
||||
gaID := os.Getenv("GOOGLE_ANALYTICS_ID")
|
||||
// Google Analytics 4 (gtag.js)
|
||||
analyticsInjectBuilder.WriteString("<script async src=\"https://www.googletagmanager.com/gtag/js?id=")
|
||||
analyticsInjectBuilder.WriteString(gaID)
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
analyticsInjectBuilder.WriteString("<script>")
|
||||
analyticsInjectBuilder.WriteString("window.dataLayer = window.dataLayer || [];")
|
||||
analyticsInjectBuilder.WriteString("function gtag(){dataLayer.push(arguments);}")
|
||||
analyticsInjectBuilder.WriteString("gtag('js', new Date());")
|
||||
analyticsInjectBuilder.WriteString("gtag('config', '")
|
||||
analyticsInjectBuilder.WriteString(gaID)
|
||||
analyticsInjectBuilder.WriteString("');")
|
||||
analyticsInjectBuilder.WriteString("</script>")
|
||||
}
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
|
||||
func InitResources() error {
|
||||
|
||||
@@ -102,7 +102,7 @@ func Distribute() func(c *gin.Context) {
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(distributor): %s", showGroup, modelRequest.Model, err.Error())
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
//if channel != nil {
|
||||
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
|
||||
@@ -46,6 +46,7 @@ const (
|
||||
LogTypeManage
|
||||
LogTypeSystem
|
||||
LogTypeError
|
||||
LogTypeRefund
|
||||
)
|
||||
|
||||
func formatUserLogs(logs []*Log) {
|
||||
|
||||
@@ -53,5 +53,5 @@ type TaskAdaptor interface {
|
||||
}
|
||||
|
||||
type OpenAIVideoConverter interface {
|
||||
ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error)
|
||||
ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error)
|
||||
}
|
||||
|
||||
@@ -1,25 +1,36 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ClientMode int
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
ClientModeApiKey ClientMode = iota + 1
|
||||
ClientModeAKSK
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
ClientMode ClientMode
|
||||
AwsClient *bedrockruntime.Client
|
||||
AwsModelId string
|
||||
AwsReq any
|
||||
IsNova bool
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
@@ -28,8 +39,37 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
for i, message := range request.Messages {
|
||||
updated := false
|
||||
if !message.IsStringContent() {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse message content")
|
||||
}
|
||||
for i2, mediaMessage := range content {
|
||||
if mediaMessage.Source != nil {
|
||||
if mediaMessage.Source.Type == "url" {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
mediaMessage.Source.MediaType = fileData.MimeType
|
||||
mediaMessage.Source.Data = fileData.Base64Data
|
||||
mediaMessage.Source.Url = ""
|
||||
mediaMessage.Source.Type = "base64"
|
||||
content[i2] = mediaMessage
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
message.SetContent(content)
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
request.Messages[i] = message
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -44,15 +84,28 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.RequestMode = RequestModeMessage
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
a.ClientMode = ClientModeApiKey
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 2 {
|
||||
return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>")
|
||||
}
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil
|
||||
} else {
|
||||
a.ClientMode = ClientModeAKSK
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
claude.CommonClaudeHeadersOperation(c, req, info)
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -63,22 +116,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
// 检查是否为Nova模型
|
||||
if isNovaModel(request.Model) {
|
||||
novaReq := convertToNovaRequest(request)
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", novaReq)
|
||||
c.Set("is_nova_model", true)
|
||||
a.IsNova = true
|
||||
return novaReq, nil
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "failed to convert openai request to claude request")
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
c.Set("is_nova_model", false)
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
@@ -97,14 +144,27 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return nil, nil
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return doAwsClientRequest(c, info, a, requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
if a.ClientMode == ClientModeApiKey {
|
||||
claudeAdaptor := claude.Adaptor{}
|
||||
usage, err = claudeAdaptor.DoResponse(c, resp, info)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a.RequestMode)
|
||||
if a.IsNova {
|
||||
err, usage = handleNovaRequest(c, info, a)
|
||||
} else {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, info, a)
|
||||
} else {
|
||||
err, usage = awsHandler(c, info, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,5 +124,5 @@ var ChannelName = "aws"
|
||||
|
||||
// 判断是否为Nova模型
|
||||
func isNovaModel(modelId string) bool {
|
||||
return strings.HasPrefix(modelId, "nova-")
|
||||
return strings.Contains(modelId, "nova-")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
@@ -35,6 +38,16 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func formatRequest(requestBody io.Reader) (*AwsClaudeRequest, error) {
|
||||
var awsClaudeRequest AwsClaudeRequest
|
||||
err := common.DecodeJson(requestBody, &awsClaudeRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
|
||||
return &awsClaudeRequest, nil
|
||||
}
|
||||
|
||||
// NovaMessage Nova模型使用messages-v1格式
|
||||
type NovaMessage struct {
|
||||
Role string `json:"role"`
|
||||
|
||||
@@ -3,6 +3,7 @@ package aws
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -49,12 +50,72 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: dto.OpenAIError{
|
||||
Message: fmt.Sprintf("%s", err.Error()),
|
||||
},
|
||||
func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor, requestBody io.Reader) (any, error) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelAwsClientError)
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
if isNovaModel(awsModelId) {
|
||||
var novaReq *NovaRequest
|
||||
err = common.DecodeJson(requestBody, &novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "decode nova request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := common.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
return nil, nil
|
||||
} else {
|
||||
awsClaudeReq, err := formatRequest(requestBody)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "format aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
} else {
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
|
||||
}
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,50 +149,9 @@ func awsModelID(requestModel string) string {
|
||||
return requestModel
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -149,46 +169,15 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
|
||||
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -207,7 +196,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
switch v := event.(type) {
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
@@ -220,32 +209,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
}
|
||||
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
@@ -477,8 +477,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
signatureContent := "\n"
|
||||
choice.Delta.ReasoningContent = &signatureContent
|
||||
case "thinking_delta":
|
||||
thinkingContent := claudeResponse.Delta.Thinking
|
||||
choice.Delta.ReasoningContent = &thinkingContent
|
||||
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
@@ -513,7 +512,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
|
||||
var responseThinking string
|
||||
if len(claudeResponse.Content) > 0 {
|
||||
responseText = claudeResponse.Content[0].GetText()
|
||||
responseThinking = claudeResponse.Content[0].Thinking
|
||||
if claudeResponse.Content[0].Thinking != nil {
|
||||
responseThinking = *claudeResponse.Content[0].Thinking
|
||||
}
|
||||
}
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
thinkingContent := ""
|
||||
@@ -545,7 +546,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
|
||||
})
|
||||
case "thinking":
|
||||
// 加密的不管, 只输出明文的推理过程
|
||||
thinkingContent = message.Thinking
|
||||
if message.Thinking != nil {
|
||||
thinkingContent = *message.Thinking
|
||||
}
|
||||
case "text":
|
||||
responseText = message.GetText()
|
||||
}
|
||||
@@ -598,8 +601,8 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
if claudeResponse.Delta.Thinking != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
|
||||
@@ -1061,11 +1061,11 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
//return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
|
||||
} else {
|
||||
return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
//if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
// return nil, types.NewOpenAIError(errors.New("request blocked by Gemini API: "+*geminiResponse.PromptFeedback.BlockReason), types.ErrorCodePromptBlocked, http.StatusBadRequest)
|
||||
//} else {
|
||||
// return nil, types.NewOpenAIError(errors.New("empty response from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
//}
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
@@ -35,8 +36,27 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
// 解析extra到SFImageRequest里,以填入SiliconFlow特殊字段。若失败重建一个空的。
|
||||
sfRequest := &SFImageRequest{}
|
||||
extra, err := common.Marshal(request.Extra)
|
||||
if err == nil {
|
||||
err = common.Unmarshal(extra, sfRequest)
|
||||
if err != nil {
|
||||
sfRequest = &SFImageRequest{}
|
||||
}
|
||||
}
|
||||
|
||||
sfRequest.Model = request.Model
|
||||
sfRequest.Prompt = request.Prompt
|
||||
// 优先使用image_size/batch_size,否则使用OpenAI标准的size/n
|
||||
if sfRequest.ImageSize == "" {
|
||||
sfRequest.ImageSize = request.Size
|
||||
}
|
||||
if sfRequest.BatchSize == 0 {
|
||||
sfRequest.BatchSize = request.N
|
||||
}
|
||||
|
||||
return sfRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -51,6 +71,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
return fmt.Sprintf("%s/v1/images/generations", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
@@ -102,6 +124,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fallthrough
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
|
||||
@@ -15,3 +15,18 @@ type SFRerankResponse struct {
|
||||
Results []dto.RerankResponseResult `json:"results"`
|
||||
Meta SFMeta `json:"meta"`
|
||||
}
|
||||
|
||||
type SFImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
ImageSize string `json:"image_size,omitempty"`
|
||||
BatchSize uint `json:"batch_size,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
NumInferenceSteps uint `json:"num_inference_steps,omitempty"`
|
||||
GuidanceScale float64 `json:"guidance_scale,omitempty"`
|
||||
Cfg float64 `json:"cfg,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Image2 string `json:"image2,omitempty"`
|
||||
Image3 string `json:"image3,omitempty"`
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -446,7 +447,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error) {
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var jimengResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
|
||||
@@ -467,7 +468,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIV
|
||||
}
|
||||
}
|
||||
|
||||
return openAIVideo, nil
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/samber/lo"
|
||||
@@ -367,7 +368,7 @@ func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error) {
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var klingResp responsePayload
|
||||
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal kling task data failed")
|
||||
@@ -396,6 +397,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIV
|
||||
Code: fmt.Sprintf("%d", klingResp.Code),
|
||||
}
|
||||
}
|
||||
|
||||
return openAIVideo, nil
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -107,7 +106,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
|
||||
// Parse Sora response
|
||||
var dResp responseTask
|
||||
if err := json.Unmarshal(responseBody, &dResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &dResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -154,7 +153,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
@@ -186,11 +185,6 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*dto.OpenAIVideo, error) {
|
||||
openAIVideo := &dto.OpenAIVideo{}
|
||||
err := json.Unmarshal(task.Data, openAIVideo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed")
|
||||
}
|
||||
return openAIVideo, nil
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
return task.Data, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -263,7 +264,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIVideo, error) {
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var viduResp taskResultResponse
|
||||
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
|
||||
@@ -287,5 +288,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*dto.OpenAIV
|
||||
}
|
||||
}
|
||||
|
||||
return openAIVideo, nil
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
@@ -512,6 +512,13 @@ type TaskInfo struct {
|
||||
TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
|
||||
}
|
||||
|
||||
func FailTaskInfo(reason string) *TaskInfo {
|
||||
return &TaskInfo{
|
||||
Status: "FAILURE",
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
|
||||
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
|
||||
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
|
||||
@@ -218,7 +218,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
other := service.GenerateMjOtherInfo(info, priceData)
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
@@ -518,7 +518,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
|
||||
}
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
other := service.GenerateMjOtherInfo(relayInfo, priceData)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
ModelName: modelName,
|
||||
|
||||
@@ -165,6 +165,9 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
}
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
@@ -394,12 +397,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
|
||||
openAIVideo, err := converter.ConvertToOpenAIVideo(originTask)
|
||||
openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
respBody, _ = json.Marshal(openAIVideo)
|
||||
respBody = openAIVideoData
|
||||
return
|
||||
}
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
|
||||
|
||||
@@ -352,7 +352,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Type: "thinking",
|
||||
Thinking: "",
|
||||
Thinking: common.GetPointer[string](""),
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -360,7 +360,7 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
// text delta
|
||||
claudeResponse.Delta = &dto.ClaudeMediaMessage{
|
||||
Type: "thinking_delta",
|
||||
Thinking: reasoning,
|
||||
Thinking: &reasoning,
|
||||
}
|
||||
} else {
|
||||
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -10,6 +12,25 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func appendRequestPath(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other map[string]interface{}) {
|
||||
if other == nil {
|
||||
return
|
||||
}
|
||||
if ctx != nil && ctx.Request != nil && ctx.Request.URL != nil {
|
||||
if path := ctx.Request.URL.Path; path != "" {
|
||||
other["request_path"] = path
|
||||
return
|
||||
}
|
||||
}
|
||||
if relayInfo != nil && relayInfo.RequestURLPath != "" {
|
||||
path := relayInfo.RequestURLPath
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
path = path[:idx]
|
||||
}
|
||||
other["request_path"] = path
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
|
||||
cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
@@ -42,6 +63,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
@@ -78,12 +100,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
|
||||
func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = priceData.ModelPrice
|
||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||
if priceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
appendRequestPath(nil, relayInfo, other)
|
||||
return other
|
||||
}
|
||||
|
||||
@@ -62,6 +62,9 @@ const (
|
||||
ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
|
||||
ErrorCodeAccessDenied ErrorCode = "access_denied"
|
||||
|
||||
// request error
|
||||
ErrorCodeBadRequestBody ErrorCode = "bad_request_body"
|
||||
|
||||
// response error
|
||||
ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
|
||||
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
|
||||
/>
|
||||
<title>New API</title>
|
||||
<analytics></analytics>
|
||||
<!--umami-->
|
||||
<!--Google Analytics-->
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
@@ -110,7 +110,7 @@ function type2secretPrompt(type) {
|
||||
case 50:
|
||||
return '按照如下格式输入: AccessKey|SecretKey, 如果上游是New API,则直接输ApiKey';
|
||||
case 51:
|
||||
return '按照如下格式输入: Access Key ID|Secret Access Key';
|
||||
return '按照如下格式输入: AccessKey|SecretAccessKey';
|
||||
default:
|
||||
return '请输入渠道对应的鉴权密钥';
|
||||
}
|
||||
@@ -153,6 +153,8 @@ const EditChannelModal = (props) => {
|
||||
settings: '',
|
||||
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
|
||||
vertex_key_type: 'json',
|
||||
// 仅 AWS: 密钥格式和区域(存入 settings.aws_key_type 和 settings.aws_region)
|
||||
aws_key_type: 'ak_sk',
|
||||
// 企业账户设置
|
||||
is_enterprise_account: false,
|
||||
// 字段透传控制默认值
|
||||
@@ -515,6 +517,8 @@ const EditChannelModal = (props) => {
|
||||
parsedSettings.azure_responses_version || '';
|
||||
// 读取 Vertex 密钥格式
|
||||
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
|
||||
// 读取 AWS 密钥格式和区域
|
||||
data.aws_key_type = parsedSettings.aws_key_type || 'ak_sk';
|
||||
// 读取企业账户设置
|
||||
data.is_enterprise_account =
|
||||
parsedSettings.openrouter_enterprise === true;
|
||||
@@ -528,6 +532,7 @@ const EditChannelModal = (props) => {
|
||||
data.azure_responses_version = '';
|
||||
data.region = '';
|
||||
data.vertex_key_type = 'json';
|
||||
data.aws_key_type = 'ak_sk';
|
||||
data.is_enterprise_account = false;
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
@@ -536,6 +541,7 @@ const EditChannelModal = (props) => {
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
data.vertex_key_type = 'json';
|
||||
data.aws_key_type = 'ak_sk';
|
||||
data.is_enterprise_account = false;
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
@@ -997,6 +1003,11 @@ const EditChannelModal = (props) => {
|
||||
localInputs.is_enterprise_account === true;
|
||||
}
|
||||
|
||||
// type === 33 (AWS): 保存 aws_key_type 到 settings
|
||||
if (localInputs.type === 33) {
|
||||
settings.aws_key_type = localInputs.aws_key_type || 'ak_sk';
|
||||
}
|
||||
|
||||
// type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
|
||||
if (localInputs.type === 1 || localInputs.type === 14) {
|
||||
settings.allow_service_tier = localInputs.allow_service_tier === true;
|
||||
@@ -1020,6 +1031,8 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.is_enterprise_account;
|
||||
// 顶层的 vertex_key_type 不应发送给后端
|
||||
delete localInputs.vertex_key_type;
|
||||
// 顶层的 aws_key_type 不应发送给后端
|
||||
delete localInputs.aws_key_type;
|
||||
// 清理字段透传控制的临时字段
|
||||
delete localInputs.allow_service_tier;
|
||||
delete localInputs.disable_store;
|
||||
@@ -1468,6 +1481,31 @@ const EditChannelModal = (props) => {
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
|
||||
{inputs.type === 33 && (
|
||||
<>
|
||||
<Form.Select
|
||||
field='aws_key_type'
|
||||
label={t('密钥格式')}
|
||||
placeholder={t('请选择密钥格式')}
|
||||
optionList={[
|
||||
{
|
||||
label: 'AccessKey / SecretAccessKey',
|
||||
value: 'ak_sk',
|
||||
},
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
]}
|
||||
style={{ width: '100%' }}
|
||||
value={inputs.aws_key_type || 'ak_sk'}
|
||||
onChange={(value) => {
|
||||
handleChannelOtherSettingsChange('aws_key_type', value);
|
||||
}}
|
||||
extraText={t(
|
||||
'AK/SK 模式:使用 AccessKey 和 SecretAccessKey;API Key 模式:使用 API Key',
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{inputs.type === 41 && (
|
||||
<Form.Select
|
||||
field='vertex_key_type'
|
||||
@@ -1536,7 +1574,15 @@ const EditChannelModal = (props) => {
|
||||
<Form.TextArea
|
||||
field='key'
|
||||
label={t('密钥')}
|
||||
placeholder={t('请输入密钥,一行一个')}
|
||||
placeholder={
|
||||
inputs.type === 33
|
||||
? inputs.aws_key_type === 'api_key'
|
||||
? t('请输入 API Key,一行一个,格式:APIKey|Region')
|
||||
: t(
|
||||
'请输入密钥,一行一个,格式:AccessKey|SecretAccessKey|Region',
|
||||
)
|
||||
: t('请输入密钥,一行一个')
|
||||
}
|
||||
rules={
|
||||
isEdit
|
||||
? []
|
||||
@@ -1730,7 +1776,13 @@ const EditChannelModal = (props) => {
|
||||
? t('密钥(编辑模式下,保存的密钥不会显示)')
|
||||
: t('密钥')
|
||||
}
|
||||
placeholder={t(type2secretPrompt(inputs.type))}
|
||||
placeholder={
|
||||
inputs.type === 33
|
||||
? inputs.aws_key_type === 'api_key'
|
||||
? t('请输入 API Key,格式:APIKey|Region')
|
||||
: t('按照如下格式输入:AccessKey|SecretAccessKey|Region')
|
||||
: t(type2secretPrompt(inputs.type))
|
||||
}
|
||||
rules={
|
||||
isEdit
|
||||
? []
|
||||
|
||||
@@ -468,6 +468,12 @@ export const useLogsData = () => {
|
||||
});
|
||||
}
|
||||
}
|
||||
if (other?.request_path) {
|
||||
expandDataLocal.push({
|
||||
key: t('请求路径'),
|
||||
value: other.request_path,
|
||||
});
|
||||
}
|
||||
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||
}
|
||||
|
||||
|
||||
@@ -1675,6 +1675,7 @@
|
||||
"请求失败": "Request failed",
|
||||
"请求头覆盖": "Request header override",
|
||||
"请求并计费模型": "Request and charge model",
|
||||
"请求路径": "Request path",
|
||||
"请求时长: ${time}s": "Request time: ${time}s",
|
||||
"请求次数": "Number of Requests",
|
||||
"请求结束后多退少补": "Adjust after request completion",
|
||||
|
||||
@@ -1684,6 +1684,7 @@
|
||||
"请求失败": "Échec de la demande",
|
||||
"请求头覆盖": "Remplacement des en-têtes de demande",
|
||||
"请求并计费模型": "Modèle de demande et de facturation",
|
||||
"请求路径": "Chemin de requête",
|
||||
"请求时长: ${time}s": "Durée de la requête : ${time}s",
|
||||
"请求次数": "Nombre de demandes",
|
||||
"请求结束后多退少补": "Ajuster après la fin de la demande",
|
||||
@@ -2081,4 +2082,4 @@
|
||||
"默认测试模型": "Modèle de test par défaut",
|
||||
"默认补全倍率": "Taux de complétion par défaut"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1693,6 +1693,7 @@
|
||||
"请求失败": "Запрос не удался",
|
||||
"请求头覆盖": "Переопределение заголовков запроса",
|
||||
"请求并计费模型": "Запрос и выставление счёта модели",
|
||||
"请求路径": "Путь запроса",
|
||||
"请求时长: ${time}s": "Время запроса: ${time}s",
|
||||
"请求次数": "Количество запросов",
|
||||
"请求结束后多退少补": "После вывода запроса возврат излишков и доплата недостатка",
|
||||
|
||||
@@ -1666,6 +1666,7 @@
|
||||
"请求失败": "请求失败",
|
||||
"请求头覆盖": "请求头覆盖",
|
||||
"请求并计费模型": "请求并计费模型",
|
||||
"请求路径": "请求路径",
|
||||
"请求时长: ${time}s": "请求时长: ${time}s",
|
||||
"请求次数": "请求次数",
|
||||
"请求结束后多退少补": "请求结束后多退少补",
|
||||
|
||||
Reference in New Issue
Block a user