Merge branch 'QuantumNous:main' into fix-claude-haiku

papersnake committed 2025-12-26 16:23:34 +08:00 (committed by GitHub)
122 changed files with 2366 additions and 806 deletions

View File

@@ -67,8 +67,11 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, info, usage.(*dto.Usage), "")
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage), "")
}
return nil
}

View File

@@ -47,7 +47,7 @@ type TaskAdaptor interface {
GetChannelName() string
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
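The interface change threads each channel's proxy into status polling; every adaptor below gains the same fourth parameter. A minimal caller sketch, assuming the service.GetHttpClientWithProxy helper introduced in those adaptors; pollTaskOnce is an illustrative name, not code from this diff:

// Illustrative only: poll a task once through the new four-argument FetchTask.
func pollTaskOnce(a TaskAdaptor, baseURL, key, taskID, proxy string) (*relaycommon.TaskInfo, error) {
	resp, err := a.FetchTask(baseURL, key, map[string]any{"task_id": taskID}, proxy)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body) // needs "io"
	if err != nil {
		return nil, err
	}
	return a.ParseTaskResult(body)
}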

View File

@@ -18,7 +18,7 @@ var awsModelIDMap = map[string]string{
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",

View File

@@ -18,6 +18,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
@@ -129,7 +130,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = common.Marshal(awsClaudeReq)
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
@@ -141,7 +142,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
awsReq.Body, err = common.Marshal(awsClaudeReq)
awsReq.Body, err = buildAwsRequestBody(c, info, awsClaudeReq)
if err != nil {
return nil, types.NewError(errors.Wrap(err, "marshal aws request fail"), types.ErrorCodeBadRequestBody)
}
@@ -151,6 +152,24 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
}
}
// buildAwsRequestBody prepares the payload for AWS requests, applying passthrough rules when enabled.
func buildAwsRequestBody(c *gin.Context, info *relaycommon.RelayInfo, awsClaudeReq any) ([]byte, error) {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return nil, errors.Wrap(err, "get request body for pass-through fail")
}
var data map[string]interface{}
if err := common.Unmarshal(body, &data); err != nil {
return nil, errors.Wrap(err, "pass-through unmarshal request body fail")
}
delete(data, "model")
delete(data, "stream")
return common.Marshal(data)
}
return common.Marshal(awsClaudeReq)
}
func getAwsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
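When pass-through is enabled, buildAwsRequestBody forwards the caller's original JSON minus the model and stream fields, presumably because Bedrock identifies the model via the invocation target and streaming via the choice of endpoint. A standalone sketch of that transformation, using encoding/json in place of the repo's common wrappers:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	raw := []byte(`{"model":"claude-haiku-4-5-20251001","stream":true,"max_tokens":256}`)
	var data map[string]interface{}
	if err := json.Unmarshal(raw, &data); err != nil {
		panic(err)
	}
	delete(data, "model")  // the AWS model ID is carried in the request URL
	delete(data, "stream") // streaming is chosen by the Bedrock API endpoint
	out, _ := json.Marshal(data)
	fmt.Println(string(out)) // {"max_tokens":256}
}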

View File

@@ -150,7 +150,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
@@ -175,7 +175,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
return types.NewError(fmt.Errorf("%s", baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
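Both replacements fix go vet's non-constant-format-string warning: when an upstream message is passed directly as the format argument, any % inside it is parsed as a formatting verb. A small illustration with an assumed message, not from this diff:

package main

import "fmt"

func main() {
	msg := "quota 95% used"
	fmt.Println(fmt.Errorf(msg))       // "quota 95%!u(MISSING)sed" — "% u" parsed as a verb
	fmt.Println(fmt.Errorf("%s", msg)) // "quota 95% used" — passed through verbatim
}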

View File

@@ -9,6 +9,7 @@ var ModelList = []string{
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-haiku-4-5-20251001",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",

View File

@@ -483,9 +483,11 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
}
}
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {

View File

@@ -208,7 +208,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
return
}
common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
common.SysLog(fmt.Sprintf("stream event error: %v %v", errorData.Code, errorData.Message))
}
}

View File

@@ -13,6 +13,7 @@ import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -137,7 +138,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} else if baseModel, level := parseThinkingLevelSuffix(info.UpstreamModelName); level != "" {
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
info.UpstreamModelName = baseModel
}
}
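reasoning.TrimEffortSuffix centralizes the per-adaptor suffix parsing (the local parseThinkingLevelSuffix helper is removed further down). Its implementation is not part of this diff; the following is a plausible sketch consistent with its (base, level, ok) call sites, with a hypothetical level list:

// Hypothetical sketch; the real level list lives in setting/reasoning. Needs "strings".
func TrimEffortSuffix(model string) (base string, level string, ok bool) {
	for _, lvl := range []string{"minimal", "low", "medium", "high", "xhigh"} {
		if suffix := "-" + lvl; strings.HasSuffix(model, suffix) {
			return strings.TrimSuffix(model, suffix), lvl, true
		}
	}
	return model, "", false
}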

View File

@@ -94,10 +94,10 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
helper.SetEventStreamHeaders(c)
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
// Send the GeminiChatResponse payload through directly
err := helper.StringData(c, data)
if err != nil {
logger.LogError(c, err.Error())
logger.LogError(c, "failed to write stream data: "+err.Error())
return false
}
info.SendResponseCount++
return true

View File

@@ -98,6 +98,7 @@ func clampThinkingBudget(modelName string, budget int) int {
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
// "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
func clampThinkingBudgetByEffort(modelName string, effort string) int {
isNew25Pro := isNew25ProModel(modelName)
is25FlashLite := is25FlashLiteModel(modelName)
@@ -118,18 +119,12 @@ func clampThinkingBudgetByEffort(modelName string, effort string) int {
maxBudget = maxBudget * 50 / 100
case "low":
maxBudget = maxBudget * 20 / 100
case "minimal":
maxBudget = maxBudget * 5 / 100
}
return clampThinkingBudget(modelName, maxBudget)
}
func parseThinkingLevelSuffix(modelName string) (string, string) {
base, level, ok := reasoning.TrimEffortSuffix(modelName)
if !ok {
return modelName, ""
}
return base, level
}
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.UpstreamModelName
@@ -186,7 +181,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
ThinkingBudget: common.GetPointer(0),
}
}
} else if _, level := parseThinkingLevelSuffix(modelName); level != "" {
} else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true,
ThinkingLevel: level,
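With the new "minimal" case, effort maps to roughly 80/50/20/5 percent of the model's maximum thinking budget before clamping. Worked numbers for an assumed maxBudget of 24576 (real limits come from clampThinkingBudget):

maxBudget := 24576
_ = maxBudget * 80 / 100 // high    -> 19660
_ = maxBudget * 50 / 100 // medium  -> 12288
_ = maxBudget * 20 / 100 // low     -> 4915
_ = maxBudget * 5 / 100  // minimal -> 1228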

View File

@@ -42,7 +42,7 @@ type Adaptor struct {
// supports OAI models: o1-mini/o3-mini/o4-mini/o1/o3, etc.
// the "minimal" effort level is only available on gpt-5
func parseReasoningEffortFromModelSuffix(model string) (string, string) {
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none"}
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium", "-none", "-xhigh"}
for _, suffix := range effortSuffixes {
if strings.HasSuffix(model, suffix) {
effort := strings.TrimPrefix(suffix, "-")
@@ -306,10 +306,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
request.Temperature = nil
}
// gpt-5 family adaptation: zero out parameters the models no longer support
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
if info.UpstreamModelName != "gpt-5-chat-latest" {
request.Temperature = nil
}
request.Temperature = nil
request.TopP = 0 // OpenAI's top_p defaults to 1.0; set 0 explicitly so omitempty drops the field from the request
request.LogProbs = false
}
// Convert the reasoning-effort suffix on the model name
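Adding "-xhigh" to the suffix list yields, assuming the function returns (baseModel, effort):

// parseReasoningEffortFromModelSuffix("gpt-5-xhigh")  -> ("gpt-5", "xhigh")
// parseReasoningEffortFromModelSuffix("o3-mini-high") -> ("o3-mini", "high")
// parseReasoningEffortFromModelSuffix("gpt-4o")       -> ("gpt-4o", "")

Note that a name ending in "-xhigh" cannot accidentally match "-high" first, since the character before "high" is "x", not "-".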

View File

@@ -0,0 +1,145 @@
package openai
import (
"bytes"
"fmt"
"io"
"math"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// The status code has already been checked; a failure while reading the body
// must be treated as a non-recoverable error and should not be returned for an
// external retry. Analogous to nginx load balancing, a request is retried only
// when it cannot be sent or the upstream returns a specific status code; once
// the upstream has written the header, any subsequent response-body failure is
// non-recoverable and can be terminated directly.
defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.TotalTokens = info.GetEstimatePromptTokens()
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
if info.IsStream {
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if service.SundaySearch(data, "usage") {
var simpleResponse dto.SimpleResponse
err := common.Unmarshal([]byte(data), &simpleResponse)
if err != nil {
logger.LogError(c, err.Error())
}
if simpleResponse.Usage.TotalTokens != 0 {
usage.PromptTokens = simpleResponse.Usage.InputTokens
usage.CompletionTokens = simpleResponse.OutputTokens
usage.TotalTokens = simpleResponse.TotalTokens
}
}
_ = helper.StringData(c, data)
return true
})
} else {
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
// Read the response body into a buffer
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
logger.LogError(c, fmt.Sprintf("failed to read TTS response body: %v", err))
c.Writer.WriteHeaderNow()
return usage
}
// Write the response to the client
c.Writer.WriteHeaderNow()
_, err = c.Writer.Write(bodyBytes)
if err != nil {
logger.LogError(c, fmt.Sprintf("failed to write TTS response: %v", err))
}
// Compute the audio duration and update usage
audioFormat := "mp3" // default format
if audioReq, ok := info.Request.(*dto.AudioRequest); ok && audioReq.ResponseFormat != "" {
audioFormat = audioReq.ResponseFormat
}
var duration float64
var durationErr error
if audioFormat == "pcm" {
// PCM has no file header; compute duration from OpenAI TTS's PCM parameters:
// sample rate 24000 Hz, bit depth 16-bit (2 bytes), 1 channel
const sampleRate = 24000
const bytesPerSample = 2
const channels = 1
duration = float64(len(bodyBytes)) / float64(sampleRate*bytesPerSample*channels)
} else {
ext := "." + audioFormat
reader := bytes.NewReader(bodyBytes)
duration, durationErr = common.GetAudioDuration(c.Request.Context(), reader, ext)
}
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
if durationErr != nil {
logger.LogWarn(c, fmt.Sprintf("failed to get audio duration: %v", durationErr))
// If the duration cannot be determined, fall back to CompletionTokens estimated from the body size
sizeInKB := float64(len(bodyBytes)) / 1000.0
estimatedTokens := int(math.Ceil(sizeInKB)) // rough estimate: ~1 token per KB
usage.CompletionTokens = estimatedTokens
usage.CompletionTokenDetails.AudioTokens = estimatedTokens
} else if duration > 0 {
// Token count: ceil(duration) / 60.0 * 1000, i.e. 1000 tokens per minute
completionTokens := int(math.Round(math.Ceil(duration) / 60.0 * 1000))
usage.CompletionTokens = completionTokens
usage.CompletionTokenDetails.AudioTokens = completionTokens
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// Write the new response body
service.IOCopyBytesGracefully(c, resp, responseBody)
var responseData struct {
Usage *dto.Usage `json:"usage"`
}
if err := common.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
if responseData.Usage.TotalTokens > 0 {
usage := responseData.Usage
if usage.PromptTokens == 0 {
usage.PromptTokens = usage.InputTokens
}
if usage.CompletionTokens == 0 {
usage.CompletionTokens = usage.OutputTokens
}
return nil, usage
}
}
usage := &dto.Usage{}
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
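Two worked examples for the new TTS accounting, using the constants in the handler (24 kHz 16-bit mono PCM; 1000 audio tokens per minute, with duration rounded up to whole seconds):

// PCM duration = bytes / (sampleRate * bytesPerSample * channels):
// a 480000-byte PCM body -> 480000 / (24000*2*1) = 10.0 seconds.
duration := float64(480000) / float64(24000*2*1)

// Completion tokens = round(ceil(duration)/60*1000):
// 95.2 s of audio -> ceil 96 -> 96/60*1000 = 1600 tokens.
tokens := int(math.Round(math.Ceil(95.2) / 60.0 * 1000))
_, _ = duration, tokens // needs "math"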

View File

@@ -172,7 +172,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}

View File

@@ -1,7 +1,6 @@
package openai
import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -151,7 +150,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var streamResp struct {
Usage *dto.Usage `json:"usage"`
}
err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
err := common.Unmarshal([]byte(secondLastStreamData), &streamResp)
if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
usage = streamResp.Usage
containStreamUsage = true
@@ -327,68 +326,6 @@ func streamTTSResponse(c *gin.Context, resp *http.Response) {
}
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.TotalTokens = info.GetEstimatePromptTokens()
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
if isStreaming {
streamTTSResponse(c, resp)
} else {
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
logger.LogError(c, err.Error())
}
}
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// Write the new response body
service.IOCopyBytesGracefully(c, resp, responseBody)
var responseData struct {
Usage *dto.Usage `json:"usage"`
}
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
if responseData.Usage.TotalTokens > 0 {
usage := responseData.Usage
if usage.PromptTokens == 0 {
usage.PromptTokens = usage.InputTokens
}
if usage.CompletionTokens == 0 {
usage.CompletionTokens = usage.OutputTokens
}
return nil, usage
}
}
usage := &dto.Usage{}
usage.PromptTokens = info.GetEstimatePromptTokens()
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
@@ -659,7 +596,7 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
case constant.ChannelTypeZhipu_v4:
case constant.ChannelTypeZhipu_v4, constant.ChannelTypeMoonshot:
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
@@ -687,7 +624,7 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
} `json:"usage"`
}
if err := json.Unmarshal(body, &payload); err != nil {
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}

View File

@@ -393,7 +393,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
// FetchTask queries the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -408,7 +408,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -146,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
// FetchTask fetches the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -163,7 +163,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -200,7 +200,7 @@ func (a *TaskAdaptor) GetChannelName() string {
}
// FetchTask fetches the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -223,7 +223,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("x-goog-api-key", key)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {

View File

@@ -110,7 +110,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return hResp.TaskID, responseBody, nil
}
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -126,7 +126,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -196,7 +196,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
if jResp.Code != 10000 {
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
return
}
@@ -210,7 +210,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
// FetchTask fetches the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -251,7 +251,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
return nil, errors.Wrap(err, "sign request failed")
}
}
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -186,7 +186,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
if kResp.Code != 0 {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest)
return
}
ov := dto.NewOpenAIVideo()
@@ -199,7 +199,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
}
// FetchTask fetches the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -228,7 +228,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -5,8 +5,10 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
@@ -67,11 +69,30 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
a.apiKey = info.ApiKey
}
func validateRemixRequest(c *gin.Context) *dto.TaskError {
var req struct {
Prompt string `json:"prompt"`
}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
}
if strings.TrimSpace(req.Prompt) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
}
return nil
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
if info.Action == constant.TaskActionRemix {
return validateRemixRequest(c)
}
return relaycommon.ValidateMultipartDirect(c, info)
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.Action == constant.TaskActionRemix {
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
}
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
}
@@ -125,7 +146,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
}
// FetchTask fetches the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -140,7 +161,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Authorization", "Bearer "+key)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -105,7 +105,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return
}
if !sunoResponse.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
return
}
@@ -132,7 +132,7 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
byteBody, err := json.Marshal(body)
if err != nil {
@@ -153,11 +153,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
resp, err := service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, err
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return resp, nil
return client.Do(req)
}
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {

View File

@@ -120,7 +120,11 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
return fmt.Errorf("failed to decode credentials: %w", err)
}
token, err := vertexcore.AcquireAccessToken(*adc, "")
proxy := ""
if info != nil {
proxy = info.ChannelSetting.Proxy
}
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
if err != nil {
return fmt.Errorf("failed to acquire access token: %w", err)
}
@@ -216,7 +220,7 @@ func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generat
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
// FetchTask fetches the task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -249,7 +253,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
if err := json.Unmarshal([]byte(key), adc); err != nil {
return nil, fmt.Errorf("failed to decode credentials: %w", err)
}
token, err := vertexcore.AcquireAccessToken(*adc, "")
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
if err != nil {
return nil, fmt.Errorf("failed to acquire access token: %w", err)
}
@@ -261,7 +265,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("x-goog-user-project", adc.ProjectID)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {

View File

@@ -188,7 +188,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
return vResp.TaskId, responseBody, nil
}
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
@@ -204,7 +204,11 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Token "+key)
return service.GetHttpClient().Do(req)
client, err := service.GetHttpClientWithProxy(proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
return client.Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {

View File

@@ -17,6 +17,7 @@ import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -51,10 +52,43 @@ type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
// Vertex AI does not support functionResponse.id; keep it stripped here for consistency.
if model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
removeFunctionResponseID(request)
}
geminiAdaptor := gemini.Adaptor{}
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
}
func removeFunctionResponseID(request *dto.GeminiChatRequest) {
if request == nil {
return
}
if len(request.Contents) > 0 {
for i := range request.Contents {
if len(request.Contents[i].Parts) == 0 {
continue
}
for j := range request.Contents[i].Parts {
part := &request.Contents[i].Parts[j]
if part.FunctionResponse == nil {
continue
}
if len(part.FunctionResponse.ID) > 0 {
part.FunctionResponse.ID = nil
}
}
}
}
if len(request.Requests) > 0 {
for i := range request.Requests {
removeFunctionResponseID(&request.Requests[i])
}
}
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
@@ -182,6 +216,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
info.UpstreamModelName = baseModel
}
}

View File

@@ -36,8 +36,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -63,6 +62,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil
}
return fmt.Sprintf("%s/api/paas/v4/embeddings", baseURL), nil
case relayconstant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/paas/v4/images/generations", baseURL), nil
default:
if hasSpecialPlan && specialPlan.OpenAIBaseURL != "" {
return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil
@@ -114,6 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
return zhipu4vImageHandler(c, resp, info)
}
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)
}

View File

@@ -4,6 +4,7 @@ import (
"time"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/types"
)
// type ZhipuMessage struct {
@@ -37,7 +38,7 @@ type ZhipuV4Response struct {
Model string `json:"model"`
TextResponseChoices []dto.OpenAITextResponseChoice `json:"choices"`
Usage dto.Usage `json:"usage"`
Error dto.OpenAIError `json:"error"`
Error types.OpenAIError `json:"error"`
}
//

View File

@@ -0,0 +1,127 @@
package zhipu_4v
import (
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
type zhipuImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Quality string `json:"quality,omitempty"`
Size string `json:"size,omitempty"`
WatermarkEnabled *bool `json:"watermark_enabled,omitempty"`
UserID string `json:"user_id,omitempty"`
}
type zhipuImageResponse struct {
Created *int64 `json:"created,omitempty"`
Data []zhipuImageData `json:"data,omitempty"`
ContentFilter any `json:"content_filter,omitempty"`
Usage *dto.Usage `json:"usage,omitempty"`
Error *zhipuImageError `json:"error,omitempty"`
RequestID string `json:"request_id,omitempty"`
ExtendParam map[string]string `json:"extendParam,omitempty"`
}
type zhipuImageError struct {
Code string `json:"code"`
Message string `json:"message"`
}
type zhipuImageData struct {
Url string `json:"url,omitempty"`
ImageUrl string `json:"image_url,omitempty"`
B64Json string `json:"b64_json,omitempty"`
B64Image string `json:"b64_image,omitempty"`
}
type openAIImagePayload struct {
Created int64 `json:"created"`
Data []openAIImageData `json:"data"`
}
type openAIImageData struct {
B64Json string `json:"b64_json"`
}
func zhipu4vImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
service.CloseResponseBodyGracefully(resp)
var zhipuResp zhipuImageResponse
if err := common.Unmarshal(responseBody, &zhipuResp); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if zhipuResp.Error != nil && zhipuResp.Error.Message != "" {
return nil, types.WithOpenAIError(types.OpenAIError{
Message: zhipuResp.Error.Message,
Type: "zhipu_image_error",
Code: zhipuResp.Error.Code,
}, resp.StatusCode)
}
payload := openAIImagePayload{}
if zhipuResp.Created != nil && *zhipuResp.Created != 0 {
payload.Created = *zhipuResp.Created
} else {
payload.Created = info.StartTime.Unix()
}
for _, data := range zhipuResp.Data {
url := data.Url
if url == "" {
url = data.ImageUrl
}
if url == "" {
logger.LogWarn(c, "zhipu_image_missing_url")
continue
}
var b64 string
switch {
case data.B64Json != "":
b64 = data.B64Json
case data.B64Image != "":
b64 = data.B64Image
default:
_, downloaded, err := service.GetImageFromUrl(url)
if err != nil {
logger.LogError(c, "zhipu_image_get_b64_failed: "+err.Error())
continue
}
b64 = downloaded
}
if b64 == "" {
logger.LogWarn(c, "zhipu_image_empty_b64")
continue
}
imageData := openAIImageData{
B64Json: b64,
}
payload.Data = append(payload.Data, imageData)
}
jsonResp, err := common.Marshal(payload)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
service.IOCopyBytesGracefully(c, resp, jsonResp)
return &dto.Usage{}, nil
}
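The handler's net effect is to normalize Zhipu's image payload to the OpenAI shape: whichever of b64_json, b64_image, url, or image_url is present ends up as b64_json, downloading the image when only a URL is returned. An assumed before/after pair:

// Upstream (Zhipu) response, URL variant:
//   {"created":1735200000,"request_id":"r-1","data":[{"url":"https://example.com/img.png"}]}
// Emitted OpenAI-compatible body after service.GetImageFromUrl fetches and encodes it:
//   {"created":1735200000,"data":[{"b64_json":"iVBORw0KGgo..."}]}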

View File

@@ -11,6 +11,8 @@ import (
"github.com/tidwall/sjson"
)
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
type ConditionOperation struct {
Path string `json:"path"` // JSON path
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
@@ -186,8 +188,7 @@ func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperat
}
func processNegativeIndex(jsonStr string, path string) string {
re := regexp.MustCompile(`\.(-\d+)`)
matches := re.FindAllStringSubmatch(path, -1)
matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
if len(matches) == 0 {
return path
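Hoisting the pattern into the package-level negativeIndexRegexp compiles it once at init rather than on every processNegativeIndex call, and MustCompile then panics at startup instead of per request if the pattern is ever broken. The change in miniature (illustrative, needs "regexp"):

// Before: recompiled on every call.
func slow(path string) [][]string {
	return regexp.MustCompile(`\.(-\d+)`).FindAllStringSubmatch(path, -1)
}

// After: compiled once at package init.
var negIdx = regexp.MustCompile(`\.(-\d+)`)

func fast(path string) [][]string {
	return negIdx.FindAllStringSubmatch(path, -1)
}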

View File

@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
relayconstant "github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -81,8 +82,9 @@ type TokenCountMeta struct {
type RelayInfo struct {
TokenId int
TokenKey string
TokenGroup string
UserId int
UsingGroup string // group in use
UsingGroup string // group in use; may change when auto retries across groups
UserGroup string // the user's own group
TokenUnlimited bool
StartTime time.Time
@@ -373,6 +375,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
// An empty token group means the user's group is used
if tokenGroup == "" {
tokenGroup = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
}
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
if startTime.IsZero() {
startTime = time.Now()
@@ -400,6 +408,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
TokenGroup: tokenGroup,
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
@@ -626,3 +635,47 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
return jsonDataAfter, nil
}
// RemoveGeminiDisabledFields removes disabled fields from Gemini request JSON data
// Currently it removes the functionResponse.id field, which Vertex AI does not support
func RemoveGeminiDisabledFields(jsonData []byte) ([]byte, error) {
if !model_setting.GetGeminiSettings().RemoveFunctionResponseIdEnabled {
return jsonData, nil
}
var data map[string]interface{}
if err := common.Unmarshal(jsonData, &data); err != nil {
common.SysError("RemoveGeminiDisabledFields Unmarshal error: " + err.Error())
return jsonData, nil
}
// Process contents array
// Handle both camelCase (functionResponse) and snake_case (function_response)
if contents, ok := data["contents"].([]interface{}); ok {
for _, content := range contents {
if contentMap, ok := content.(map[string]interface{}); ok {
if parts, ok := contentMap["parts"].([]interface{}); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]interface{}); ok {
// Check functionResponse (camelCase)
if funcResp, ok := partMap["functionResponse"].(map[string]interface{}); ok {
delete(funcResp, "id")
}
// Check function_response (snake_case)
if funcResp, ok := partMap["function_response"].(map[string]interface{}); ok {
delete(funcResp, "id")
}
}
}
}
}
}
}
jsonDataAfter, err := common.Marshal(data)
if err != nil {
common.SysError("RemoveGeminiDisabledFields Marshal error: " + err.Error())
return jsonData, nil
}
return jsonDataAfter, nil
}
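Because RemoveGeminiDisabledFields works on the raw JSON map, it handles both key spellings in one pass. An assumed before/after for a single part:

// Input (a snake_case function_response part is handled identically):
//   {"contents":[{"parts":[{"functionResponse":{"id":"call_1","name":"get_weather","response":{"temp":21}}}]}]}
// Output with RemoveFunctionResponseIdEnabled:
//   {"contents":[{"parts":[{"functionResponse":{"name":"get_weather","response":{"temp":21}}}]}]}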

View File

@@ -181,7 +181,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return newApiErr
}
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage), "")
@@ -300,14 +300,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// Subtract cached tokens
// Anthropic's input_tokens already excludes cached tokens, so no subtraction is needed
// prompt_tokens from OpenAI/OpenRouter-style APIs includes cached tokens, which must be subtracted
var cachedTokensWithRatio decimal.Decimal
if !dCacheTokens.IsZero() {
baseTokens = baseTokens.Sub(dCacheTokens)
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
baseTokens = baseTokens.Sub(dCacheTokens)
}
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
var dCachedCreationTokensWithRatio decimal.Decimal
if !dCachedCreationTokens.IsZero() {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
}
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
}
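A worked example of the Anthropic carve-out, with assumed numbers (prompt 1000 tokens, 400 cached, cache ratio 0.1; decimal arithmetic simplified to floats, and the ratio-weighted cached tokens presumably summed back in downstream):

prompt, cached, cacheRatio := 1000.0, 400.0, 0.1

// OpenAI/OpenRouter style: prompt_tokens includes cache reads, so carve them out.
openAIBase := prompt - cached                  // 600
openAIBilled := openAIBase + cached*cacheRatio // 640 token-equivalents

// Anthropic: input_tokens already excludes cache reads, so the base stays intact.
anthropicBilled := prompt + cached*cacheRatio  // 1040 token-equivalents
_, _ = openAIBilled, anthropicBilled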

View File

@@ -14,15 +14,28 @@ import (
"github.com/gorilla/websocket"
)
func FlushWriter(c *gin.Context) error {
if c.Writer == nil {
func FlushWriter(c *gin.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("flush panic recovered: %v", r)
}
}()
if c == nil || c.Writer == nil {
return nil
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
return nil
if c.Request != nil && c.Request.Context().Err() != nil {
return fmt.Errorf("request context done: %w", c.Request.Context().Err())
}
return errors.New("streaming error: flusher not found")
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return errors.New("streaming error: flusher not found")
}
flusher.Flush()
return nil
}
func SetEventStreamHeaders(c *gin.Context) {
@@ -66,17 +79,31 @@ func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data st
}
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")
if c == nil || c.Writer == nil {
return errors.New("context or writer is nil")
}
if c.Request != nil && c.Request.Context().Err() != nil {
return fmt.Errorf("request context done: %w", c.Request.Context().Err())
}
c.Render(-1, common.CustomEvent{Data: "data: " + str})
_ = FlushWriter(c)
return nil
return FlushWriter(c)
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
_ = FlushWriter(c)
return nil
if c == nil || c.Writer == nil {
return errors.New("context or writer is nil")
}
if c.Request != nil && c.Request.Context().Err() != nil {
return fmt.Errorf("request context done: %w", c.Request.Context().Err())
}
if _, err := c.Writer.Write([]byte(": PING\n\n")); err != nil {
return fmt.Errorf("write ping data failed: %w", err)
}
return FlushWriter(c)
}
func ObjectData(c *gin.Context, object interface{}) error {

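With FlushWriter now recovering from flush panics and StringData propagating the flush error, a streaming loop can stop writing as soon as the client disconnects. A minimal consumer sketch; the handler itself is assumed, not from this diff:

// Assumed SSE relay loop built on the hardened helpers.
func streamLines(c *gin.Context, lines <-chan string) {
	helper.SetEventStreamHeaders(c)
	for line := range lines {
		if err := helper.StringData(c, line); err != nil {
			// Context cancelled or flush failed: the client is gone, stop writing.
			logger.LogError(c, "stop streaming: "+err.Error())
			return
		}
	}
}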
View File

@@ -32,7 +32,94 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
if info.TaskRelayInfo == nil {
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
}
path := c.Request.URL.Path
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
info.Action = constant.TaskActionRemix
}
// Extract the video_id for remix tasks
if info.Action == constant.TaskActionRemix {
videoID := c.Param("video_id")
if strings.TrimSpace(videoID) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest)
}
info.OriginTaskID = videoID
}
platform := constant.TaskPlatform(c.GetString("platform"))
// Fetch the original task
if info.OriginTaskID != "" {
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
if info.OriginModelName == "" {
if originTask.Properties.OriginModelName != "" {
info.OriginModelName = originTask.Properties.OriginModelName
} else if originTask.Properties.UpstreamModelName != "" {
info.OriginModelName = originTask.Properties.UpstreamModelName
} else {
var taskData map[string]interface{}
_ = json.Unmarshal(originTask.Data, &taskData)
if m, ok := taskData["model"].(string); ok && m != "" {
info.OriginModelName = m
platform = originTask.Platform
}
}
}
if originTask.ChannelId != info.ChannelId {
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
return
}
if channel.Status != common.ChannelStatusEnabled {
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
return
}
key, _, newAPIError := channel.GetNextEnabledKey()
if newAPIError != nil {
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
return
}
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
info.ChannelBaseUrl = channel.GetBaseURL()
info.ChannelId = originTask.ChannelId
info.ChannelType = channel.Type
info.ApiKey = key
platform = originTask.Platform
}
// Reuse the original task's parameters
if info.Action == constant.TaskActionRemix {
var taskData map[string]interface{}
_ = json.Unmarshal(originTask.Data, &taskData)
secondsStr, _ := taskData["seconds"].(string)
seconds, _ := strconv.Atoi(secondsStr)
if seconds <= 0 {
seconds = 4
}
sizeStr, _ := taskData["size"].(string)
if info.PriceData.OtherRatios == nil {
info.PriceData.OtherRatios = map[string]float64{}
}
info.PriceData.OtherRatios["seconds"] = float64(seconds)
info.PriceData.OtherRatios["size"] = 1
if sizeStr == "1792x1024" || sizeStr == "1024x1792" {
info.PriceData.OtherRatios["size"] = 1.666667
}
}
}
if platform == "" {
platform = GetTaskPlatform(c)
}
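For remix billing the origin task supplies the parameters: the seconds ratio is the clip length (default 4) and the size ratio is 1 except for 1792x1024 / 1024x1792, which are priced at 1.666667. Assuming the ratios multiply into a hypothetical per-unit base price:

basePrice := 0.05                   // hypothetical quota per unit
seconds, size := 8.0, 1.666667      // from the origin task; size is 1 for other resolutions
price := basePrice * seconds * size // ≈ 0.6667
_ = price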
@@ -94,34 +181,6 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
return
}
if info.OriginTaskID != "" {
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
if originTask.ChannelId != info.ChannelId {
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
return
}
if channel.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
info.ChannelBaseUrl = channel.GetBaseURL()
info.ChannelId = originTask.ChannelId
}
}
// build body
requestBody, err := adaptor.BuildRequestBody(c, info)
if err != nil {
@@ -137,7 +196,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
// handle response
if resp != nil && resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
return
}
@@ -326,6 +385,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
if channelModel.GetBaseURL() != "" {
baseURL = channelModel.GetBaseURL()
}
proxy := channelModel.GetSetting().Proxy
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
if adaptor == nil {
return
@@ -333,7 +393,7 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
"task_id": originTask.TaskID,
"action": originTask.Action,
})
}, proxy)
if err2 != nil || resp == nil {
return
}