mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 12:08:37 +00:00
Merge branch 'upstream-main' into fix/pr-2540
# Conflicts: # relay/channel/gemini/relay-gemini.go
This commit is contained in:
@@ -13,12 +13,37 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
IsSyncImageModel bool
|
||||
}
|
||||
|
||||
/*
|
||||
var syncModels = []string{
|
||||
"z-image",
|
||||
"qwen-image",
|
||||
"wan2.6",
|
||||
}
|
||||
*/
|
||||
func supportsAliAnthropicMessages(modelName string) bool {
|
||||
// Only models with the "qwen" designation can use the Claude-compatible interface; others require conversion.
|
||||
return strings.Contains(strings.ToLower(modelName), "qwen")
|
||||
}
|
||||
|
||||
var syncModels = []string{
|
||||
"z-image",
|
||||
"qwen-image",
|
||||
"wan2.6",
|
||||
}
|
||||
|
||||
func isSyncImageModel(modelName string) bool {
|
||||
return model_setting.IsSyncImageModel(modelName)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
@@ -27,7 +52,18 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return req, nil
|
||||
if supportsAliAnthropicMessages(info.UpstreamModelName) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
oaiReq, err := service.ClaudeToOpenAIRequest(*req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.SupportStreamOptions && info.IsStream {
|
||||
oaiReq.StreamOptions = &dto.StreamOptions{IncludeUsage: true}
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, oaiReq)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -37,7 +73,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
|
||||
if supportsAliAnthropicMessages(info.UpstreamModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/apps/anthropic/v1/messages", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
|
||||
}
|
||||
default:
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
@@ -45,10 +85,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
}
|
||||
case constant.RelayModeImagesEdits:
|
||||
if isWanModel(info.OriginModelName) {
|
||||
if isOldWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
|
||||
} else if isWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
}
|
||||
@@ -72,7 +118,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
|
||||
} else {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
@@ -108,15 +158,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
a.IsSyncImageModel = true
|
||||
}
|
||||
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
if isOldWanModel(info.OriginModelName) {
|
||||
return oaiFormEdit2WanxImageEdit(c, info, request)
|
||||
}
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
a.IsSyncImageModel = false
|
||||
} else {
|
||||
a.IsSyncImageModel = true
|
||||
}
|
||||
}
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
@@ -126,9 +186,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
}
|
||||
@@ -150,7 +210,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -161,21 +221,22 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
if supportsAliAnthropicMessages(info.UpstreamModelName) {
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
default:
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
err, usage = aliImageHandler(a, c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
if isWanModel(info.OriginModelName) {
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
}
|
||||
err, usage = aliImageHandler(a, c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package ali
|
||||
|
||||
import "github.com/QuantumNous/new-api/dto"
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AliMessage struct {
|
||||
Content any `json:"content"`
|
||||
@@ -65,6 +72,7 @@ type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
type TaskResult struct {
|
||||
@@ -75,14 +83,78 @@ type TaskResult struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []map[string]any `json:"choices,omitempty"`
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []AliMediaContent `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
} `json:"message,omitempty"`
|
||||
} `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||
var imageData []dto.ImageData
|
||||
if len(o.Choices) > 0 {
|
||||
for _, choice := range o.Choices {
|
||||
var data dto.ImageData
|
||||
for _, content := range choice.Message.Content {
|
||||
if content.Image != "" {
|
||||
if strings.HasPrefix(content.Image, "http") {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(content.Image)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
}
|
||||
data.Url = content.Image
|
||||
data.B64Json = b64Json
|
||||
} else {
|
||||
data.B64Json = content.Image
|
||||
}
|
||||
} else if content.Text != "" {
|
||||
data.RevisedPrompt = content.Text
|
||||
}
|
||||
}
|
||||
imageData = append(imageData, data)
|
||||
}
|
||||
}
|
||||
|
||||
return imageData
|
||||
}
|
||||
|
||||
func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||
var imageData []dto.ImageData
|
||||
for _, data := range o.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
} else {
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageData = append(imageData, dto.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
}
|
||||
return imageData
|
||||
}
|
||||
|
||||
type AliResponse struct {
|
||||
@@ -92,18 +164,26 @@ type AliResponse struct {
|
||||
}
|
||||
|
||||
type AliImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters AliImageParameters `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type AliImageParameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AliImageParameters) PromptExtendValue() bool {
|
||||
if p != nil && p.PromptExtend != nil {
|
||||
return *p.PromptExtend
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AliImageInput struct {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -21,17 +20,23 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||
if request.Extra != nil {
|
||||
if val, ok := request.Extra["parameters"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 兼容没有parameters字段的情况,从openai标准字段中提取参数
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
}
|
||||
if val, ok := request.Extra["input"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Input)
|
||||
@@ -41,23 +46,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Parameters == nil {
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
if strings.Contains(request.Model, "z-image") {
|
||||
// z-image 开启prompt_extend后,按2倍计费
|
||||
if imageRequest.Parameters.PromptExtendValue() {
|
||||
info.PriceData.AddOtherRatio("prompt_extend", 2)
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
// 检查n参数
|
||||
if imageRequest.Parameters.N != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
}
|
||||
|
||||
// 同步图片模型和异步图片模型请求格式不一样
|
||||
if isSync {
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Messages: []AliMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []AliMediaContent{
|
||||
{
|
||||
Text: request.Prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
@@ -199,6 +225,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
|
||||
var taskResponse AliResponse
|
||||
var responseBody []byte
|
||||
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
|
||||
for {
|
||||
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||
step++
|
||||
@@ -238,32 +266,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
|
||||
Created: info.StartTime.Unix(),
|
||||
}
|
||||
|
||||
for _, data := range response.Output.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
} else {
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
if len(response.Output.Results) > 0 {
|
||||
imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
|
||||
} else if len(response.Output.Choices) > 0 {
|
||||
imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
|
||||
}
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(originBody, &mapResponse)
|
||||
imageResponse.Extra = mapResponse
|
||||
|
||||
imageResponse.Metadata = originBody
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
@@ -282,66 +295,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var (
|
||||
aliResponse *AliResponse
|
||||
originRespBody []byte
|
||||
)
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Message != "" {
|
||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var fullTextResponse dto.ImageResponse
|
||||
if len(aliResponse.Output.Choices) > 0 {
|
||||
fullTextResponse = dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
Data: []dto.ImageData{
|
||||
{
|
||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||
B64Json: "",
|
||||
},
|
||||
},
|
||||
if a.IsSyncImageModel {
|
||||
aliResponse = &aliTaskResponse
|
||||
originRespBody = responseBody
|
||||
} else {
|
||||
// 异步图片模型需要轮询任务结果
|
||||
aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
}
|
||||
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||
fullTextResponse.Extra = mapResponse
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
|
||||
if a.IsSyncImageModel {
|
||||
logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
|
||||
} else {
|
||||
logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
|
||||
}
|
||||
|
||||
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
// 可能生成多张图片,修正计费数量n
|
||||
if aliResponse.Usage.ImageCount != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
||||
} else if len(imageResponses.Data) != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
|
||||
}
|
||||
jsonResponse, err := common.Marshal(imageResponses)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
|
||||
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
|
||||
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
||||
}
|
||||
wanParams := WanImageParameters{
|
||||
//wanParams := WanImageParameters{
|
||||
// N: int(request.N),
|
||||
//}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
N: int(request.N),
|
||||
}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = wanParams
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func isOldWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
|
||||
}
|
||||
|
||||
func isWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan")
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -37,6 +39,13 @@ func getAwsErrorStatusCode(err error) int {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func newAwsInvokeContext() (context.Context, context.CancelFunc) {
|
||||
if common.RelayTimeout <= 0 {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second)
|
||||
}
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
var (
|
||||
httpClient *http.Client
|
||||
@@ -117,6 +126,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
} else {
|
||||
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
|
||||
@@ -201,7 +211,10 @@ func getAwsModelID(requestModel string) string {
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
ctx, cancel := newAwsInvokeContext()
|
||||
defer cancel()
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
statusCode := getAwsErrorStatusCode(err)
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
||||
@@ -228,7 +241,10 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
ctx, cancel := newAwsInvokeContext()
|
||||
defer cancel()
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
if err != nil {
|
||||
statusCode := getAwsErrorStatusCode(err)
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
||||
@@ -268,7 +284,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
ctx, cancel := newAwsInvokeContext()
|
||||
defer cancel()
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
statusCode := getAwsErrorStatusCode(err)
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
||||
|
||||
164
relay/channel/codex/adaptor.go
Normal file
164
relay/channel/codex/adaptor.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return nil, errors.New("codex channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
if info != nil && info.ChannelSetting.SystemPrompt != "" {
|
||||
systemPrompt := info.ChannelSetting.SystemPrompt
|
||||
|
||||
if len(request.Instructions) == 0 {
|
||||
if b, err := common.Marshal(systemPrompt); err == nil {
|
||||
request.Instructions = b
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
var existing string
|
||||
if err := common.Unmarshal(request.Instructions, &existing); err == nil {
|
||||
existing = strings.TrimSpace(existing)
|
||||
if existing == "" {
|
||||
if b, err := common.Marshal(systemPrompt); err == nil {
|
||||
request.Instructions = b
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if b, err := common.Marshal(systemPrompt + "\n" + existing); err == nil {
|
||||
request.Instructions = b
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if b, err := common.Marshal(systemPrompt); err == nil {
|
||||
request.Instructions = b
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// codex: store must be false
|
||||
request.Store = json.RawMessage("false")
|
||||
// rm max_output_tokens
|
||||
request.MaxOutputTokens = 0
|
||||
request.Temperature = nil
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode != relayconstant.RelayModeResponses {
|
||||
return nil, types.NewError(errors.New("codex channel: endpoint not supported"), types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
return openai.OaiResponsesStreamHandler(c, info, resp)
|
||||
}
|
||||
return openai.OaiResponsesHandler(c, info, resp)
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode != relayconstant.RelayModeResponses {
|
||||
return "", errors.New("codex channel: only /v1/responses is supported")
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, "/backend-api/codex/responses", info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
|
||||
key := strings.TrimSpace(info.ApiKey)
|
||||
if !strings.HasPrefix(key, "{") {
|
||||
return errors.New("codex channel: key must be a JSON object")
|
||||
}
|
||||
|
||||
oauthKey, err := ParseOAuthKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(oauthKey.AccessToken)
|
||||
accountID := strings.TrimSpace(oauthKey.AccountID)
|
||||
|
||||
if accessToken == "" {
|
||||
return errors.New("codex channel: access_token is required")
|
||||
}
|
||||
if accountID == "" {
|
||||
return errors.New("codex channel: account_id is required")
|
||||
}
|
||||
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Set("chatgpt-account-id", accountID)
|
||||
|
||||
if req.Get("OpenAI-Beta") == "" {
|
||||
req.Set("OpenAI-Beta", "responses=experimental")
|
||||
}
|
||||
if req.Get("originator") == "" {
|
||||
req.Set("originator", "codex_cli_rs")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
9
relay/channel/codex/constants.go
Normal file
9
relay/channel/codex/constants.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package codex
|
||||
|
||||
var ModelList = []string{
|
||||
"gpt-5", "gpt-5-codex", "gpt-5-codex-mini",
|
||||
"gpt-5.1", "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini",
|
||||
"gpt-5.2", "gpt-5.2-codex",
|
||||
}
|
||||
|
||||
const ChannelName = "codex"
|
||||
30
relay/channel/codex/oauth_key.go
Normal file
30
relay/channel/codex/oauth_key.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package codex
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
)
|
||||
|
||||
type OAuthKey struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
LastRefresh string `json:"last_refresh,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Expired string `json:"expired,omitempty"`
|
||||
}
|
||||
|
||||
func ParseOAuthKey(raw string) (*OAuthKey, error) {
|
||||
if raw == "" {
|
||||
return nil, errors.New("codex channel: empty oauth key")
|
||||
}
|
||||
var key OAuthKey
|
||||
if err := common.Unmarshal([]byte(raw), &key); err != nil {
|
||||
return nil, errors.New("codex channel: invalid oauth key json")
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -32,6 +34,7 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true, // support old image/jpeg
|
||||
"image/webp": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
@@ -381,7 +384,7 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
var system_content []string
|
||||
//shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "system" {
|
||||
if message.Role == "system" || message.Role == "developer" {
|
||||
system_content = append(system_content, message.StringContent())
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
@@ -659,101 +662,84 @@ func getSupportedMimeTypesList() []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
var geminiOpenAPISchemaAllowedFields = map[string]struct{}{
|
||||
"anyOf": {},
|
||||
"default": {},
|
||||
"description": {},
|
||||
"enum": {},
|
||||
"example": {},
|
||||
"format": {},
|
||||
"items": {},
|
||||
"maxItems": {},
|
||||
"maxLength": {},
|
||||
"maxProperties": {},
|
||||
"maximum": {},
|
||||
"minItems": {},
|
||||
"minLength": {},
|
||||
"minProperties": {},
|
||||
"minimum": {},
|
||||
"nullable": {},
|
||||
"pattern": {},
|
||||
"properties": {},
|
||||
"propertyOrdering": {},
|
||||
"required": {},
|
||||
"title": {},
|
||||
"type": {},
|
||||
}
|
||||
|
||||
const geminiFunctionSchemaMaxDepth = 64
|
||||
|
||||
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
|
||||
func cleanFunctionParameters(params interface{}) interface{} {
|
||||
return cleanFunctionParametersWithDepth(params, 0)
|
||||
}
|
||||
|
||||
func cleanFunctionParametersWithDepth(params interface{}, depth int) interface{} {
|
||||
if params == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if depth >= geminiFunctionSchemaMaxDepth {
|
||||
return cleanFunctionParametersShallow(params)
|
||||
}
|
||||
|
||||
switch v := params.(type) {
|
||||
case map[string]interface{}:
|
||||
// Create a copy to avoid modifying the original
|
||||
cleanedMap := make(map[string]interface{})
|
||||
// Keep only Gemini-supported OpenAPI schema subset fields (per official SDK Schema).
|
||||
cleanedMap := make(map[string]interface{}, len(v))
|
||||
for k, val := range v {
|
||||
cleanedMap[k] = val
|
||||
}
|
||||
|
||||
// Remove unsupported root-level fields
|
||||
delete(cleanedMap, "default")
|
||||
delete(cleanedMap, "exclusiveMaximum")
|
||||
delete(cleanedMap, "exclusiveMinimum")
|
||||
delete(cleanedMap, "$schema")
|
||||
delete(cleanedMap, "additionalProperties")
|
||||
|
||||
// Check and clean 'format' for string types
|
||||
if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
|
||||
if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
|
||||
if formatValue != "enum" && formatValue != "date-time" {
|
||||
delete(cleanedMap, "format")
|
||||
}
|
||||
if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
|
||||
cleanedMap[k] = val
|
||||
}
|
||||
}
|
||||
|
||||
normalizeGeminiSchemaTypeAndNullable(cleanedMap)
|
||||
|
||||
// Clean properties
|
||||
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
|
||||
cleanedProps := make(map[string]interface{})
|
||||
for propName, propValue := range props {
|
||||
cleanedProps[propName] = cleanFunctionParameters(propValue)
|
||||
cleanedProps[propName] = cleanFunctionParametersWithDepth(propValue, depth+1)
|
||||
}
|
||||
cleanedMap["properties"] = cleanedProps
|
||||
}
|
||||
|
||||
// Recursively clean items in arrays
|
||||
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
|
||||
cleanedMap["items"] = cleanFunctionParameters(items)
|
||||
cleanedMap["items"] = cleanFunctionParametersWithDepth(items, depth+1)
|
||||
}
|
||||
// Also handle items if it's an array of schemas
|
||||
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
|
||||
cleanedItemsArray := make([]interface{}, len(itemsArray))
|
||||
for i, item := range itemsArray {
|
||||
cleanedItemsArray[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
cleanedMap["items"] = cleanedItemsArray
|
||||
// OpenAPI tuple-style items is not supported by Gemini SDK Schema; keep first to avoid API rejection.
|
||||
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok && len(itemsArray) > 0 {
|
||||
cleanedMap["items"] = cleanFunctionParametersWithDepth(itemsArray[0], depth+1)
|
||||
}
|
||||
|
||||
// Recursively clean other schema composition keywords
|
||||
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
||||
if nested, ok := cleanedMap[field].([]interface{}); ok {
|
||||
cleanedNested := make([]interface{}, len(nested))
|
||||
for i, item := range nested {
|
||||
cleanedNested[i] = cleanFunctionParameters(item)
|
||||
}
|
||||
cleanedMap[field] = cleanedNested
|
||||
}
|
||||
}
|
||||
|
||||
// Recursively clean patternProperties
|
||||
if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
|
||||
cleanedPatternProps := make(map[string]interface{})
|
||||
for pattern, schema := range patternProps {
|
||||
cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
|
||||
}
|
||||
cleanedMap["patternProperties"] = cleanedPatternProps
|
||||
}
|
||||
|
||||
// Recursively clean definitions
|
||||
if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
|
||||
cleanedDefinitions := make(map[string]interface{})
|
||||
for defName, defSchema := range definitions {
|
||||
cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
|
||||
}
|
||||
cleanedMap["definitions"] = cleanedDefinitions
|
||||
}
|
||||
|
||||
// Recursively clean $defs (newer JSON Schema draft)
|
||||
if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
|
||||
cleanedDefs := make(map[string]interface{})
|
||||
for defName, defSchema := range defs {
|
||||
cleanedDefs[defName] = cleanFunctionParameters(defSchema)
|
||||
}
|
||||
cleanedMap["$defs"] = cleanedDefs
|
||||
}
|
||||
|
||||
// Clean conditional keywords
|
||||
for _, field := range []string{"if", "then", "else", "not"} {
|
||||
if nested, ok := cleanedMap[field]; ok {
|
||||
cleanedMap[field] = cleanFunctionParameters(nested)
|
||||
// Recursively clean anyOf
|
||||
if nested, ok := cleanedMap["anyOf"].([]interface{}); ok && nested != nil {
|
||||
cleanedNested := make([]interface{}, len(nested))
|
||||
for i, item := range nested {
|
||||
cleanedNested[i] = cleanFunctionParametersWithDepth(item, depth+1)
|
||||
}
|
||||
cleanedMap["anyOf"] = cleanedNested
|
||||
}
|
||||
|
||||
return cleanedMap
|
||||
@@ -762,7 +748,7 @@ func cleanFunctionParameters(params interface{}) interface{} {
|
||||
// Handle arrays of schemas
|
||||
cleanedArray := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
cleanedArray[i] = cleanFunctionParameters(item)
|
||||
cleanedArray[i] = cleanFunctionParametersWithDepth(item, depth+1)
|
||||
}
|
||||
return cleanedArray
|
||||
|
||||
@@ -772,6 +758,91 @@ func cleanFunctionParameters(params interface{}) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func cleanFunctionParametersShallow(params interface{}) interface{} {
|
||||
switch v := params.(type) {
|
||||
case map[string]interface{}:
|
||||
cleanedMap := make(map[string]interface{}, len(v))
|
||||
for k, val := range v {
|
||||
if _, ok := geminiOpenAPISchemaAllowedFields[k]; ok {
|
||||
cleanedMap[k] = val
|
||||
}
|
||||
}
|
||||
normalizeGeminiSchemaTypeAndNullable(cleanedMap)
|
||||
// Stop recursion and avoid retaining huge nested structures.
|
||||
delete(cleanedMap, "properties")
|
||||
delete(cleanedMap, "items")
|
||||
delete(cleanedMap, "anyOf")
|
||||
return cleanedMap
|
||||
case []interface{}:
|
||||
// Prefer an empty list over deep recursion on attacker-controlled inputs.
|
||||
return []interface{}{}
|
||||
default:
|
||||
return params
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeGeminiSchemaTypeAndNullable(schema map[string]interface{}) {
|
||||
rawType, ok := schema["type"]
|
||||
if !ok || rawType == nil {
|
||||
return
|
||||
}
|
||||
|
||||
normalize := func(t string) (string, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(t)) {
|
||||
case "object":
|
||||
return "OBJECT", false
|
||||
case "array":
|
||||
return "ARRAY", false
|
||||
case "string":
|
||||
return "STRING", false
|
||||
case "integer":
|
||||
return "INTEGER", false
|
||||
case "number":
|
||||
return "NUMBER", false
|
||||
case "boolean":
|
||||
return "BOOLEAN", false
|
||||
case "null":
|
||||
return "", true
|
||||
default:
|
||||
return t, false
|
||||
}
|
||||
}
|
||||
|
||||
switch t := rawType.(type) {
|
||||
case string:
|
||||
normalized, isNull := normalize(t)
|
||||
if isNull {
|
||||
schema["nullable"] = true
|
||||
delete(schema, "type")
|
||||
return
|
||||
}
|
||||
schema["type"] = normalized
|
||||
case []interface{}:
|
||||
nullable := false
|
||||
var chosen string
|
||||
for _, item := range t {
|
||||
if s, ok := item.(string); ok {
|
||||
normalized, isNull := normalize(s)
|
||||
if isNull {
|
||||
nullable = true
|
||||
continue
|
||||
}
|
||||
if chosen == "" {
|
||||
chosen = normalized
|
||||
}
|
||||
}
|
||||
}
|
||||
if nullable {
|
||||
schema["nullable"] = true
|
||||
}
|
||||
if chosen != "" {
|
||||
schema["type"] = chosen
|
||||
} else {
|
||||
delete(schema, "type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
||||
if depth >= 5 {
|
||||
return schema
|
||||
@@ -1183,6 +1254,8 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
finishReason := constant.FinishReasonStop
|
||||
toolCallIndexByChoice := make(map[int]map[string]int)
|
||||
nextToolCallIndexByChoice := make(map[int]int)
|
||||
|
||||
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
|
||||
@@ -1190,6 +1263,28 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
for choiceIdx := range response.Choices {
|
||||
choiceKey := response.Choices[choiceIdx].Index
|
||||
for toolIdx := range response.Choices[choiceIdx].Delta.ToolCalls {
|
||||
tool := &response.Choices[choiceIdx].Delta.ToolCalls[toolIdx]
|
||||
if tool.ID == "" {
|
||||
continue
|
||||
}
|
||||
m := toolCallIndexByChoice[choiceKey]
|
||||
if m == nil {
|
||||
m = make(map[string]int)
|
||||
toolCallIndexByChoice[choiceKey] = m
|
||||
}
|
||||
if idx, ok := m[tool.ID]; ok {
|
||||
tool.SetIndex(idx)
|
||||
continue
|
||||
}
|
||||
idx := nextToolCallIndexByChoice[choiceKey]
|
||||
nextToolCallIndexByChoice[choiceKey] = idx + 1
|
||||
m[tool.ID] = idx
|
||||
tool.SetIndex(idx)
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
||||
if info.SendResponseCount == 0 {
|
||||
@@ -1417,6 +1512,79 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
type GeminiModelsResponse struct {
|
||||
Models []dto.GeminiModel `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
|
||||
client, err := service.GetHttpClientWithProxy(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
|
||||
}
|
||||
|
||||
allModels := make([]string, 0)
|
||||
nextPageToken := ""
|
||||
maxPages := 100 // Safety limit to prevent infinite loops
|
||||
|
||||
for page := 0; page < maxPages; page++ {
|
||||
url := fmt.Sprintf("%s/v1beta/models", baseURL)
|
||||
if nextPageToken != "" {
|
||||
url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
response.Body.Close()
|
||||
cancel()
|
||||
return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
response.Body.Close()
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
var modelsResponse GeminiModelsResponse
|
||||
if err = common.Unmarshal(body, &modelsResponse); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
for _, model := range modelsResponse.Models {
|
||||
modelNameValue, ok := model.Name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
modelName := strings.TrimPrefix(modelNameValue, "models/")
|
||||
allModels = append(allModels, modelName)
|
||||
}
|
||||
|
||||
nextPageToken = modelsResponse.NextPageToken
|
||||
if nextPageToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return allModels, nil
|
||||
}
|
||||
|
||||
// convertToolChoiceToGeminiConfig converts OpenAI tool_choice to Gemini toolConfig
|
||||
// OpenAI tool_choice values:
|
||||
// - "auto": Let the model decide (default)
|
||||
|
||||
@@ -14,6 +14,9 @@ var ModelList = []string{
|
||||
"speech-02-turbo",
|
||||
"speech-01-hd",
|
||||
"speech-01-turbo",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.1-lightning",
|
||||
"MiniMax-M2",
|
||||
}
|
||||
|
||||
var ChannelName = "minimax"
|
||||
|
||||
@@ -67,3 +67,40 @@ type OllamaEmbeddingResponse struct {
|
||||
Embeddings [][]float64 `json:"embeddings"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaTagsResponse struct {
|
||||
Models []OllamaModel `json:"models"`
|
||||
}
|
||||
|
||||
type OllamaModel struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Details OllamaModelDetail `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaModelDetail struct {
|
||||
ParentModel string `json:"parent_model,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Families []string `json:"families,omitempty"`
|
||||
ParameterSize string `json:"parameter_size,omitempty"`
|
||||
QuantizationLevel string `json:"quantization_level,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaPullRequest struct {
|
||||
Name string `json:"name"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaPullResponse struct {
|
||||
Status string `json:"status"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
Total int64 `json:"total,omitempty"`
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaDeleteRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -283,3 +285,246 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func FetchOllamaModels(baseURL, apiKey string) ([]OllamaModel, error) {
|
||||
url := fmt.Sprintf("%s/api/tags", baseURL)
|
||||
|
||||
client := &http.Client{}
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
// Ollama 通常不需要 Bearer token,但为了兼容性保留
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tagsResponse OllamaTagsResponse
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
err = common.Unmarshal(body, &tagsResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
return tagsResponse.Models, nil
|
||||
}
|
||||
|
||||
// 拉取 Ollama 模型 (非流式)
|
||||
func PullOllamaModel(baseURL, apiKey, modelName string) error {
|
||||
url := fmt.Sprintf("%s/api/pull", baseURL)
|
||||
|
||||
pullRequest := OllamaPullRequest{
|
||||
Name: modelName,
|
||||
Stream: false, // 非流式,简化处理
|
||||
}
|
||||
|
||||
requestBody, err := common.Marshal(pullRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 30 * 60 * 1000 * time.Millisecond, // 30分钟超时,支持大模型
|
||||
}
|
||||
request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 流式拉取 Ollama 模型 (支持进度回调)
|
||||
func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback func(OllamaPullResponse)) error {
|
||||
url := fmt.Sprintf("%s/api/pull", baseURL)
|
||||
|
||||
pullRequest := OllamaPullRequest{
|
||||
Name: modelName,
|
||||
Stream: true, // 启用流式
|
||||
}
|
||||
|
||||
requestBody, err := common.Marshal(pullRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 60 * 60 * 1000 * time.Millisecond, // 1小时超时,支持超大模型
|
||||
}
|
||||
request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 读取流式响应
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
successful := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var pullResponse OllamaPullResponse
|
||||
if err := common.Unmarshal([]byte(line), &pullResponse); err != nil {
|
||||
continue // 忽略解析失败的行
|
||||
}
|
||||
|
||||
if progressCallback != nil {
|
||||
progressCallback(pullResponse)
|
||||
}
|
||||
|
||||
// 检查是否出现错误或完成
|
||||
if strings.EqualFold(pullResponse.Status, "error") {
|
||||
return fmt.Errorf("拉取模型失败: %s", strings.TrimSpace(line))
|
||||
}
|
||||
if strings.EqualFold(pullResponse.Status, "success") {
|
||||
successful = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取流式响应失败: %v", err)
|
||||
}
|
||||
|
||||
if !successful {
|
||||
return fmt.Errorf("拉取模型未完成: 未收到成功状态")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除 Ollama 模型
|
||||
func DeleteOllamaModel(baseURL, apiKey, modelName string) error {
|
||||
url := fmt.Sprintf("%s/api/delete", baseURL)
|
||||
|
||||
deleteRequest := OllamaDeleteRequest{
|
||||
Name: modelName,
|
||||
}
|
||||
|
||||
requestBody, err := common.Marshal(deleteRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
request, err := http.NewRequest("DELETE", url, strings.NewReader(string(requestBody)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("删除模型失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchOllamaVersion(baseURL, apiKey string) (string, error) {
|
||||
trimmedBase := strings.TrimRight(baseURL, "/")
|
||||
if trimmedBase == "" {
|
||||
return "", fmt.Errorf("baseURL 为空")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/version", trimmedBase)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("查询版本失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var versionResp struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &versionResp); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if versionResp.Version == "" {
|
||||
return "", fmt.Errorf("未返回版本信息")
|
||||
}
|
||||
|
||||
return versionResp.Version, nil
|
||||
}
|
||||
|
||||
369
relay/channel/openai/chat_via_responses.go
Normal file
369
relay/channel/openai/chat_via_responses.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var responsesResp dto.OpenAIResponsesResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &responsesResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := responsesResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
chatId := helper.GetResponseID(c)
|
||||
chatResp, usage, err := service.ResponsesResponseToChatCompletionsResponse(&responsesResp, chatId)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if usage == nil || usage.TotalTokens == 0 {
|
||||
text := service.ExtractOutputTextFromResponses(&responsesResp)
|
||||
usage = service.ResponseText2Usage(c, text, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
chatResp.Usage = *usage
|
||||
}
|
||||
|
||||
chatBody, err := common.Marshal(chatResp)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, chatBody)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseId := helper.GetResponseID(c)
|
||||
createAt := time.Now().Unix()
|
||||
model := info.UpstreamModelName
|
||||
|
||||
var (
|
||||
usage = &dto.Usage{}
|
||||
outputText strings.Builder
|
||||
usageText strings.Builder
|
||||
sentStart bool
|
||||
sentStop bool
|
||||
sawToolCall bool
|
||||
streamErr *types.NewAPIError
|
||||
)
|
||||
|
||||
toolCallIndexByID := make(map[string]int)
|
||||
toolCallNameByID := make(map[string]string)
|
||||
toolCallArgsByID := make(map[string]string)
|
||||
toolCallNameSent := make(map[string]bool)
|
||||
toolCallCanonicalIDByItemID := make(map[string]string)
|
||||
|
||||
sendStartIfNeeded := func() bool {
|
||||
if sentStart {
|
||||
return true
|
||||
}
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
sentStart = true
|
||||
return true
|
||||
}
|
||||
|
||||
sendToolCallDelta := func(callID string, name string, argsDelta string) bool {
|
||||
if callID == "" {
|
||||
return true
|
||||
}
|
||||
if outputText.Len() > 0 {
|
||||
// Prefer streaming assistant text over tool calls to match non-stream behavior.
|
||||
return true
|
||||
}
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
}
|
||||
|
||||
idx, ok := toolCallIndexByID[callID]
|
||||
if !ok {
|
||||
idx = len(toolCallIndexByID)
|
||||
toolCallIndexByID[callID] = idx
|
||||
}
|
||||
if name != "" {
|
||||
toolCallNameByID[callID] = name
|
||||
}
|
||||
if toolCallNameByID[callID] != "" {
|
||||
name = toolCallNameByID[callID]
|
||||
}
|
||||
|
||||
tool := dto.ToolCallResponse{
|
||||
ID: callID,
|
||||
Type: "function",
|
||||
Function: dto.FunctionResponse{
|
||||
Arguments: argsDelta,
|
||||
},
|
||||
}
|
||||
tool.SetIndex(idx)
|
||||
if name != "" && !toolCallNameSent[callID] {
|
||||
tool.Function.Name = name
|
||||
toolCallNameSent[callID] = true
|
||||
}
|
||||
|
||||
chunk := &dto.ChatCompletionsStreamResponse{
|
||||
Id: responseId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
ToolCalls: []dto.ToolCallResponse{tool},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
sawToolCall = true
|
||||
|
||||
// Include tool call data in the local builder for fallback token estimation.
|
||||
if tool.Function.Name != "" {
|
||||
usageText.WriteString(tool.Function.Name)
|
||||
}
|
||||
if argsDelta != "" {
|
||||
usageText.WriteString(argsDelta)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if streamErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var streamResp dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
switch streamResp.Type {
|
||||
case "response.created":
|
||||
if streamResp.Response != nil {
|
||||
if streamResp.Response.Model != "" {
|
||||
model = streamResp.Response.Model
|
||||
}
|
||||
if streamResp.Response.CreatedAt != 0 {
|
||||
createAt = int64(streamResp.Response.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
case "response.output_text.delta":
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
}
|
||||
|
||||
if streamResp.Delta != "" {
|
||||
outputText.WriteString(streamResp.Delta)
|
||||
usageText.WriteString(streamResp.Delta)
|
||||
delta := streamResp.Delta
|
||||
chunk := &dto.ChatCompletionsStreamResponse{
|
||||
Id: responseId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Content: &delta,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
case "response.output_item.added", "response.output_item.done":
|
||||
if streamResp.Item == nil {
|
||||
break
|
||||
}
|
||||
if streamResp.Item.Type != "function_call" {
|
||||
break
|
||||
}
|
||||
|
||||
itemID := strings.TrimSpace(streamResp.Item.ID)
|
||||
callID := strings.TrimSpace(streamResp.Item.CallId)
|
||||
if callID == "" {
|
||||
callID = itemID
|
||||
}
|
||||
if itemID != "" && callID != "" {
|
||||
toolCallCanonicalIDByItemID[itemID] = callID
|
||||
}
|
||||
name := strings.TrimSpace(streamResp.Item.Name)
|
||||
if name != "" {
|
||||
toolCallNameByID[callID] = name
|
||||
}
|
||||
|
||||
newArgs := streamResp.Item.Arguments
|
||||
prevArgs := toolCallArgsByID[callID]
|
||||
argsDelta := ""
|
||||
if newArgs != "" {
|
||||
if strings.HasPrefix(newArgs, prevArgs) {
|
||||
argsDelta = newArgs[len(prevArgs):]
|
||||
} else {
|
||||
argsDelta = newArgs
|
||||
}
|
||||
toolCallArgsByID[callID] = newArgs
|
||||
}
|
||||
|
||||
if !sendToolCallDelta(callID, name, argsDelta) {
|
||||
return false
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.delta":
|
||||
itemID := strings.TrimSpace(streamResp.ItemID)
|
||||
callID := toolCallCanonicalIDByItemID[itemID]
|
||||
if callID == "" {
|
||||
callID = itemID
|
||||
}
|
||||
if callID == "" {
|
||||
break
|
||||
}
|
||||
toolCallArgsByID[callID] += streamResp.Delta
|
||||
if !sendToolCallDelta(callID, "", streamResp.Delta) {
|
||||
return false
|
||||
}
|
||||
|
||||
case "response.function_call_arguments.done":
|
||||
|
||||
case "response.completed":
|
||||
if streamResp.Response != nil {
|
||||
if streamResp.Response.Model != "" {
|
||||
model = streamResp.Response.Model
|
||||
}
|
||||
if streamResp.Response.CreatedAt != 0 {
|
||||
createAt = int64(streamResp.Response.CreatedAt)
|
||||
}
|
||||
if streamResp.Response.Usage != nil {
|
||||
if streamResp.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResp.Response.Usage.InputTokens
|
||||
usage.InputTokens = streamResp.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResp.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResp.Response.Usage.OutputTokens
|
||||
usage.OutputTokens = streamResp.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResp.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResp.Response.Usage.TotalTokens
|
||||
} else {
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
if streamResp.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResp.Response.Usage.InputTokensDetails.CachedTokens
|
||||
usage.PromptTokensDetails.ImageTokens = streamResp.Response.Usage.InputTokensDetails.ImageTokens
|
||||
usage.PromptTokensDetails.AudioTokens = streamResp.Response.Usage.InputTokensDetails.AudioTokens
|
||||
}
|
||||
if streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
|
||||
usage.CompletionTokenDetails.ReasoningTokens = streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !sendStartIfNeeded() {
|
||||
return false
|
||||
}
|
||||
if !sentStop {
|
||||
finishReason := "stop"
|
||||
if sawToolCall && outputText.Len() == 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if err := helper.ObjectData(c, stop); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
sentStop = true
|
||||
}
|
||||
|
||||
case "response.error", "response.failed":
|
||||
if streamResp.Response != nil {
|
||||
if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
|
||||
streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
}
|
||||
streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(c, usageText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
}
|
||||
|
||||
if !sentStart {
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
if !sentStop {
|
||||
finishReason := "stop"
|
||||
if sawToolCall && outputText.Len() == 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
|
||||
if err := helper.ObjectData(c, stop); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage && usage != nil {
|
||||
if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
return usage, nil
|
||||
}
|
||||
@@ -208,7 +208,6 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
helper.Done(c)
|
||||
|
||||
case types.RelayFormatClaude:
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
@@ -221,6 +220,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
for _, resp := range claudeResponses {
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
|
||||
case types.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
|
||||
@@ -186,7 +186,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
applyUsagePostProcessing(info, usage, nil)
|
||||
applyUsagePostProcessing(info, usage, common.StringToByteSlice(lastStreamData))
|
||||
|
||||
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
@@ -596,7 +596,8 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4, constant.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
@@ -606,6 +607,19 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeMoonshot:
|
||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -639,3 +653,32 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Usage struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 遍历choices查找cached_tokens
|
||||
for _, choice := range payload.Choices {
|
||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||
return *choice.Usage.CachedTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -192,6 +192,10 @@ func sizeToResolution(size string) (string, error) {
|
||||
func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) {
|
||||
otherRatios := make(map[string]float64)
|
||||
aliRatios := map[string]map[string]float64{
|
||||
"wan2.6-i2v": {
|
||||
"720P": 1,
|
||||
"1080P": 1 / 0.6,
|
||||
},
|
||||
"wan2.5-t2v-preview": {
|
||||
"480P": 1,
|
||||
"720P": 2,
|
||||
@@ -287,7 +291,9 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
aliReq.Parameters.Size = "1280*720"
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(req.Model, "wan2.5") {
|
||||
if strings.HasPrefix(req.Model, "wan2.6") {
|
||||
aliReq.Parameters.Resolution = "1080P"
|
||||
} else if strings.HasPrefix(req.Model, "wan2.5") {
|
||||
aliReq.Parameters.Resolution = "1080P"
|
||||
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
|
||||
aliReq.Parameters.Resolution = "720P"
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -23,18 +26,36 @@ import (
|
||||
// ============================
|
||||
|
||||
type ContentItem struct {
|
||||
Type string `json:"type"` // "text" or "image_url"
|
||||
Text string `json:"text,omitempty"` // for text type
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
|
||||
Type string `json:"type"` // "text", "image_url" or "video"
|
||||
Text string `json:"text,omitempty"` // for text type
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
|
||||
Video *VideoReference `json:"video,omitempty"` // for video (sample) type
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type VideoReference struct {
|
||||
URL string `json:"url"` // Draft video URL
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Content []ContentItem `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentItem `json:"content"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
ReturnLastFrame *dto.BoolValue `json:"return_last_frame,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
ExecutionExpiresAfter dto.IntValue `json:"execution_expires_after,omitempty"`
|
||||
GenerateAudio *dto.BoolValue `json:"generate_audio,omitempty"`
|
||||
Draft *dto.BoolValue `json:"draft,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
Ratio string `json:"ratio,omitempty"`
|
||||
Duration dto.IntValue `json:"duration,omitempty"`
|
||||
Frames dto.IntValue `json:"frames,omitempty"`
|
||||
Seed dto.IntValue `json:"seed,omitempty"`
|
||||
CameraFixed *dto.BoolValue `json:"camera_fixed,omitempty"`
|
||||
Watermark *dto.BoolValue `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -53,6 +74,7 @@ type responseTask struct {
|
||||
Duration int `json:"duration"`
|
||||
Ratio string `json:"ratio"`
|
||||
FramesPerSecond int `json:"framespersecond"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Usage struct {
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
@@ -98,16 +120,16 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
|
||||
// BuildRequestBody converts request into Doubao specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
req, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
info.UpstreamModelName = body.Model
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -141,7 +163,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = dResp.ID
|
||||
ov.TaskID = dResp.ID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
|
||||
c.JSON(http.StatusOK, ov)
|
||||
return dResp.ID, responseBody, nil
|
||||
}
|
||||
|
||||
@@ -204,12 +232,15 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add support for additional parameters from metadata
|
||||
// such as ratio, duration, seed, etc.
|
||||
// metadata := req.Metadata
|
||||
// if metadata != nil {
|
||||
// // Parse and apply metadata parameters
|
||||
// }
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
@@ -229,7 +260,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
case "pending", "queued":
|
||||
taskResult.Status = model.TaskStatusQueued
|
||||
taskResult.Progress = "10%"
|
||||
case "processing":
|
||||
case "processing", "running":
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
taskResult.Progress = "50%"
|
||||
case "succeeded":
|
||||
@@ -251,3 +282,30 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var dResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal doubao task data failed")
|
||||
}
|
||||
|
||||
openAIVideo := dto.NewOpenAIVideo()
|
||||
openAIVideo.ID = originTask.TaskID
|
||||
openAIVideo.TaskID = originTask.TaskID
|
||||
openAIVideo.Status = originTask.Status.ToVideoStatus()
|
||||
openAIVideo.SetProgressStr(originTask.Progress)
|
||||
openAIVideo.SetMetadata("url", dResp.Content.VideoURL)
|
||||
openAIVideo.CreatedAt = originTask.CreatedAt
|
||||
openAIVideo.CompletedAt = originTask.UpdatedAt
|
||||
openAIVideo.Model = originTask.Properties.OriginModelName
|
||||
|
||||
if dResp.Status == "failed" {
|
||||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||||
Message: "task failed",
|
||||
Code: "failed",
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ var ModelList = []string{
|
||||
"doubao-seedance-1-0-pro-250528",
|
||||
"doubao-seedance-1-0-lite-t2v",
|
||||
"doubao-seedance-1-0-lite-i2v",
|
||||
"doubao-seedance-1-5-pro-251215",
|
||||
}
|
||||
|
||||
var ChannelName = "doubao-video"
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
@@ -409,14 +410,15 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
|
||||
// 即梦视频3.0 ReqKey转换
|
||||
// https://www.volcengine.com/docs/85621/1792707
|
||||
imageLen := lo.Max([]int{len(req.Images), len(r.BinaryDataBase64), len(r.ImageUrls)})
|
||||
if strings.Contains(r.ReqKey, "jimeng_v30") {
|
||||
if r.ReqKey == "jimeng_v30_pro" {
|
||||
// 3.0 pro只有固定的jimeng_ti2v_v30_pro
|
||||
r.ReqKey = "jimeng_ti2v_v30_pro"
|
||||
} else if len(req.Images) > 1 {
|
||||
} else if imageLen > 1 {
|
||||
// 多张图片:首尾帧生成
|
||||
r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1), "p")
|
||||
} else if len(req.Images) == 1 {
|
||||
} else if imageLen == 1 {
|
||||
// 单张图片:图生视频
|
||||
r.ReqKey = strings.TrimSuffix(strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1), "p")
|
||||
} else {
|
||||
|
||||
@@ -346,7 +346,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
taskInfo.Code = resPayload.Code
|
||||
taskInfo.TaskID = resPayload.Data.TaskId
|
||||
taskInfo.Reason = resPayload.Message
|
||||
taskInfo.Reason = resPayload.Data.TaskStatusMsg
|
||||
//任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
|
||||
status := resPayload.Data.TaskStatus
|
||||
switch status {
|
||||
|
||||
@@ -40,6 +40,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-opus-4-20250514": "claude-opus-4@20250514",
|
||||
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
|
||||
}
|
||||
|
||||
|
||||
@@ -270,6 +270,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
case constant.RelayModeResponses:
|
||||
return fmt.Sprintf("%s/api/v3/responses", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
|
||||
@@ -323,7 +325,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
|
||||
Reference in New Issue
Block a user