mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-06-07 22:09:57 +00:00
Merge branches 'main' and 'main' of github.com:danding5/new-api
# Conflicts: # common/api_type.go # constant/api_type.go # constant/channel.go # relay/relay_adaptor.go # web/src/constants/channel.constants.js
This commit is contained in:
@@ -7,104 +7,43 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
audioRequest := &dto.AudioRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, audioRequest)
|
||||
func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
audioReq, ok := info.Request.(*dto.AudioRequest)
|
||||
if !ok {
|
||||
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(audioReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := service.CheckSensitiveInput(audioRequest.Input)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = c.Request.ParseForm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
if audioRequest.Model == "" {
|
||||
audioRequest.Model = formData.Get("model")
|
||||
}
|
||||
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
audioRequest.ResponseFormat = formData.Get("response_format")
|
||||
if audioRequest.ResponseFormat == "" {
|
||||
audioRequest.ResponseFormat = "json"
|
||||
}
|
||||
}
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
preConsumedTokens = promptTokens
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
resp, err := adaptor.DoRequest(c, info, ioReader)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
@@ -121,14 +60,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -26,19 +26,20 @@ type Adaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
|
||||
}
|
||||
|
||||
type TaskAdaptor interface {
|
||||
Init(info *relaycommon.TaskRelayInfo)
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
||||
|
||||
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
|
||||
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
||||
|
||||
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
|
||||
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -18,10 +20,13 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -29,18 +34,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
|
||||
default:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesEdits:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
||||
default:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
|
||||
}
|
||||
}
|
||||
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
@@ -53,6 +66,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
req.Set("Content-Type", "application/json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +79,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
// docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
|
||||
// fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
|
||||
if strings.Contains(request.Model, "thinking") {
|
||||
request.EnableThinking = true
|
||||
request.Stream = true
|
||||
info.IsStream = true
|
||||
}
|
||||
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
|
||||
if !info.IsStream {
|
||||
request.EnableThinking = false
|
||||
@@ -74,8 +99,30 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
aliRequest := oaiImage2Ali(request)
|
||||
return aliRequest, nil
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
aliRequest, err := oaiFormEdit2AliImageEdit(c, info, request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image edit form request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported image relay mode: %d", info.RelayMode)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -101,21 +148,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return usage, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -3,10 +3,15 @@ package ali
|
||||
import "one-api/dto"
|
||||
|
||||
type AliMessage struct {
|
||||
Content string `json:"content"`
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type AliMediaContent struct {
|
||||
Image string `json:"image,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
//History []AliMessage `json:"history,omitempty"`
|
||||
@@ -70,13 +75,14 @@ type TaskResult struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []map[string]any `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
type AliResponse struct {
|
||||
@@ -86,20 +92,26 @@ type AliResponse struct {
|
||||
}
|
||||
|
||||
type AliImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
} `json:"input"`
|
||||
Parameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type AliImageParameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type AliImageInput struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Messages []AliMessage `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
type AliRerankParameters struct {
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
@@ -17,19 +20,139 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||
if request.Extra != nil {
|
||||
if val, ok := request.Extra["parameters"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||
}
|
||||
}
|
||||
if val, ok := request.Extra["input"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input field: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Parameters == nil {
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
}
|
||||
}
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Input.Prompt = request.Prompt
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
|
||||
imageRequest.Parameters.N = request.N
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
return &imageRequest
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
mf = c.Request.MultipartForm
|
||||
}
|
||||
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range mf.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
imageFiles = append(imageFiles, files...)
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(imageFiles) == 0 {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
|
||||
if len(imageFiles) > 1 {
|
||||
return nil, errors.New("only one image is supported for qwen edit")
|
||||
}
|
||||
|
||||
// 获取base64编码的图片
|
||||
var imageBase64s []string
|
||||
for _, file := range imageFiles {
|
||||
image, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open image file")
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
imageData, err := io.ReadAll(image)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read image file")
|
||||
}
|
||||
|
||||
// 获取MIME类型
|
||||
mimeType := http.DetectContentType(imageData)
|
||||
|
||||
// 编码为base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||
|
||||
// 构造data URL格式
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
|
||||
imageBase64s = append(imageBase64s, dataURL)
|
||||
image.Close()
|
||||
}
|
||||
|
||||
//dto.MediaContent{}
|
||||
mediaContents := make([]AliMediaContent, len(imageBase64s))
|
||||
for i, b64 := range imageBase64s {
|
||||
mediaContents[i] = AliMediaContent{
|
||||
Image: b64,
|
||||
}
|
||||
}
|
||||
mediaContents = append(mediaContents, AliMediaContent{
|
||||
Text: request.Prompt,
|
||||
})
|
||||
imageRequest.Input = AliImageInput{
|
||||
Messages: []AliMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: mediaContents,
|
||||
},
|
||||
},
|
||||
}
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
|
||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
|
||||
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
|
||||
|
||||
var aliResponse AliResponse
|
||||
|
||||
@@ -43,7 +166,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysError("updateTask client.Do err: " + err.Error())
|
||||
common.SysLog("updateTask client.Do err: " + err.Error())
|
||||
return &aliResponse, err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -51,17 +174,17 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
var response AliResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
err = common.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysError("updateTask NewDecoder err: " + err.Error())
|
||||
common.SysLog("updateTask NewDecoder err: " + err.Error())
|
||||
return &aliResponse, err, nil
|
||||
}
|
||||
|
||||
return &response, nil, responseBody
|
||||
}
|
||||
|
||||
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
||||
waitSeconds := 3
|
||||
func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
|
||||
waitSeconds := 10
|
||||
step := 0
|
||||
maxStep := 20
|
||||
|
||||
@@ -69,11 +192,14 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
|
||||
var responseBody []byte
|
||||
|
||||
for {
|
||||
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||
step++
|
||||
rsp, err, body := updateTask(info, taskID)
|
||||
responseBody = body
|
||||
if err != nil {
|
||||
return &taskResponse, responseBody, err
|
||||
logger.LogWarn(c, "asyncTaskWait UpdateTask err: "+err.Error())
|
||||
time.Sleep(time.Duration(waitSeconds) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if rsp.Output.TaskStatus == "" {
|
||||
@@ -99,7 +225,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []
|
||||
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
|
||||
}
|
||||
|
||||
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
|
||||
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody []byte, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
|
||||
imageResponse := dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
}
|
||||
@@ -109,7 +235,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
common.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
@@ -123,6 +249,9 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
}
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(originBody, &mapResponse)
|
||||
imageResponse.Extra = mapResponse
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
@@ -132,20 +261,20 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
var aliTaskResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
||||
logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
@@ -159,13 +288,52 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
c.Writer.Write(jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Message != "" {
|
||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var fullTextResponse dto.ImageResponse
|
||||
if len(aliResponse.Output.Choices) > 0 {
|
||||
fullTextResponse = dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
Data: []dto.ImageData{
|
||||
{
|
||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||
B64Json: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||
fullTextResponse.Extra = mapResponse
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"one-api/types"
|
||||
@@ -43,10 +44,10 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
|
||||
var fullTextResponse dto.FlexibleEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
model := c.GetString("model")
|
||||
if model == "" {
|
||||
@@ -148,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
var aliResponse AliResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
@@ -161,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
@@ -171,7 +172,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
return false
|
||||
}
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -179,12 +180,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
common2 "one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -46,7 +48,19 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
headers := req.Header
|
||||
headerOverride := make(map[string]string)
|
||||
for k, v := range info.HeadersOverride {
|
||||
if str, ok := v.(string); ok {
|
||||
headerOverride[k] = str
|
||||
} else {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
}
|
||||
for key, value := range headerOverride {
|
||||
headers.Set(key, value)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &headers, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -71,8 +85,19 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
}
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
headers := req.Header
|
||||
headerOverride := make(map[string]string)
|
||||
for k, v := range info.HeadersOverride {
|
||||
if str, ok := v.(string); ok {
|
||||
headerOverride[k] = str
|
||||
} else {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
}
|
||||
for key, value := range headerOverride {
|
||||
headers.Set(key, value)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &headers, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -181,7 +206,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
|
||||
|
||||
err := helper.PingData(c)
|
||||
if err != nil {
|
||||
common2.LogError(c, "SSE ping error: "+err.Error())
|
||||
logger.LogError(c, "SSE ping error: "+err.Error())
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
@@ -223,7 +248,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
helper.SetEventStreamHeaders(c)
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
if generalSettings.PingIntervalEnabled && !info.DisablePing {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
@@ -252,7 +277,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
fullRequestURL, err := a.BuildRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -269,7 +294,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
resp, err := doRequest(c, req, info.RelayInfo)
|
||||
resp, err := doRequest(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", request)
|
||||
@@ -58,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -54,6 +55,9 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-opus-4-20250514-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -19,20 +18,31 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
"github.com/aws/smithy-go/auth/bearer"
|
||||
)
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 3 {
|
||||
var client *bedrockruntime.Client
|
||||
switch len(awsSecret) {
|
||||
case 2:
|
||||
apiKey := awsSecret[0]
|
||||
region := awsSecret[1]
|
||||
client = bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
|
||||
})
|
||||
case 3:
|
||||
ak := awsSecret[0]
|
||||
sk := awsSecret[1]
|
||||
region := awsSecret[2]
|
||||
client = bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||
})
|
||||
default:
|
||||
return nil, errors.New("invalid aws secret key")
|
||||
}
|
||||
ak := awsSecret[0]
|
||||
sk := awsSecret[1]
|
||||
region := awsSecret[2]
|
||||
client := bedrockruntime.New(bedrockruntime.Options{
|
||||
Region: region,
|
||||
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -102,14 +112,14 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
@@ -120,7 +130,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
// 复制上游 Content-Type 到客户端响应头
|
||||
if awsResp.ContentType != nil && *awsResp.ContentType != "" {
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
@@ -154,14 +169,14 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
awsReq.Body, err = common.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
|
||||
}
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -96,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
default:
|
||||
suffix += strings.ToLower(info.UpstreamModelName)
|
||||
}
|
||||
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
|
||||
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix)
|
||||
var accessToken string
|
||||
var err error
|
||||
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
|
||||
|
||||
@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
EnableCitation: false,
|
||||
UserId: request.User,
|
||||
}
|
||||
if request.MaxTokens != 0 {
|
||||
maxTokens := int(request.MaxTokens)
|
||||
if request.MaxTokens == 1 {
|
||||
if request.GetMaxTokens() != 0 {
|
||||
maxTokens := int(request.GetMaxTokens())
|
||||
if request.GetMaxTokens() == 1 {
|
||||
maxTokens = 2
|
||||
}
|
||||
baiduRequest.MaxOutputTokens = &maxTokens
|
||||
@@ -118,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
@@ -129,11 +129,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("error sending stream response: " + err.Error())
|
||||
common.SysLog("error sending stream response: " + err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
@@ -168,7 +168,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
@@ -18,10 +19,14 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -38,20 +43,33 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
@@ -63,20 +81,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-search") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
|
||||
request.Model = info.UpstreamModelName
|
||||
toMap := request.ToMap()
|
||||
toMap["web_search"] = map[string]any{
|
||||
"enable": true,
|
||||
"enable_citation": true,
|
||||
"enable_trace": true,
|
||||
"enable_status": false,
|
||||
if len(request.WebSearch) == 0 {
|
||||
toMap := request.ToMap()
|
||||
toMap["web_search"] = map[string]any{
|
||||
"enable": true,
|
||||
"enable_citation": true,
|
||||
"enable_trace": true,
|
||||
"enable_status": false,
|
||||
}
|
||||
return toMap, nil
|
||||
}
|
||||
return toMap, nil
|
||||
return request, nil
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
@@ -94,11 +115,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
@@ -48,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if a.RequestMode == RequestModeMessage {
|
||||
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil
|
||||
} else {
|
||||
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
return RequestOpenAI2ClaudeComplete(*request), nil
|
||||
} else {
|
||||
return RequestOpenAI2ClaudeMessage(*request)
|
||||
return RequestOpenAI2ClaudeMessage(c, *request)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,9 +102,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
return ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
|
||||
return ClaudeHandler(c, resp, info, a.RequestMode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ var ModelList = []string{
|
||||
"claude-sonnet-4-20250514-thinking",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-opus-4-20250514-thinking",
|
||||
"claude-opus-4-1-20250805",
|
||||
"claude-opus-4-1-20250805-thinking",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -31,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "max_tokens"
|
||||
return "length"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
default:
|
||||
@@ -70,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
@@ -149,7 +150,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
MaxTokens: textRequest.MaxTokens,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
@@ -273,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
|
||||
claudeMessages := make([]dto.ClaudeMessage, 0)
|
||||
isFirstMessage := true
|
||||
// 初始化system消息数组,用于累积多个system消息
|
||||
var systemMessages []dto.ClaudeMediaMessage
|
||||
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
// 根据Claude API规范,system字段使用数组格式更有通用性
|
||||
if message.IsStringContent() {
|
||||
claudeRequest.System = message.StringContent()
|
||||
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](message.StringContent()),
|
||||
})
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
content := ""
|
||||
for _, ctx := range contents {
|
||||
// 支持复合内容的system消息(虽然不常见,但需要考虑完整性)
|
||||
for _, ctx := range message.ParseContent() {
|
||||
if ctx.Type == "text" {
|
||||
content += ctx.Text
|
||||
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](ctx.Text),
|
||||
})
|
||||
}
|
||||
// 未来可以在这里扩展对图片等其他类型的支持
|
||||
}
|
||||
claudeRequest.System = content
|
||||
}
|
||||
} else {
|
||||
if isFirstMessage {
|
||||
@@ -354,7 +364,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
@@ -375,7 +385,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
for _, toolCall := range message.ParseToolCalls() {
|
||||
inputObj := make(map[string]any)
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
continue
|
||||
}
|
||||
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
|
||||
@@ -391,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置累积的system消息
|
||||
if len(systemMessages) > 0 {
|
||||
claudeRequest.System = systemMessages
|
||||
}
|
||||
|
||||
claudeRequest.Prompt = ""
|
||||
claudeRequest.Messages = claudeMessages
|
||||
return &claudeRequest, nil
|
||||
@@ -425,7 +441,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
|
||||
// 如果是文本块,尽可能发送首段文本(若存在)
|
||||
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
|
||||
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
|
||||
}
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Index: common.GetPointer(fcIdx),
|
||||
@@ -609,13 +628,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
@@ -628,7 +647,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
|
||||
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
||||
@@ -637,7 +656,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "send_stream_response_failed: "+err.Error())
|
||||
logger.LogError(c, "send_stream_response_failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -653,28 +672,27 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
//
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
@@ -691,21 +709,21 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
|
||||
return nil, claudeInfo.Usage
|
||||
return claudeInfo.Usage, nil
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
@@ -721,14 +739,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
var responseData []byte
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
case types.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
case types.RelayFormatClaude:
|
||||
responseData = data
|
||||
}
|
||||
|
||||
@@ -736,12 +754,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, nil, responseData)
|
||||
service.IOCopyBytesGracefully(c, httpResp, responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
@@ -752,16 +770,16 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
|
||||
if handleErr != nil {
|
||||
return handleErr, nil
|
||||
return nil, handleErr
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
return claudeInfo.Usage, nil
|
||||
}
|
||||
|
||||
func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
@@ -18,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -30,11 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil
|
||||
case constant.RelayModeResponses:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,8 +65,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
@@ -105,6 +112,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
} else {
|
||||
err, usage = cfHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeResponses:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OaiResponsesHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
|
||||
@@ -5,7 +5,7 @@ import "one-api/dto"
|
||||
type CfRequest struct {
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Lora string `json:"lora,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
|
||||
logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
|
||||
continue
|
||||
}
|
||||
for _, choice := range response.Choices {
|
||||
@@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
info.FirstResponseTime = time.Now()
|
||||
}
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_stream_response: "+err.Error())
|
||||
logger.LogError(c, "error_rendering_stream_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||
logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
@@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &cfResp)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -38,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||
} else {
|
||||
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ type CohereRequest struct {
|
||||
ChatHistory []ChatHistory `json:"chat_history"`
|
||||
Message string `json:"message"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
MaxTokens uint `json:"max_tokens"`
|
||||
SafetyMode string `json:"safety_mode,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
var cohereResp CohereResponse
|
||||
err := json.Unmarshal([]byte(data), &cohereResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
var openaiResp dto.ChatCompletionsStreamResponse
|
||||
@@ -153,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
jsonStr, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||
@@ -175,7 +175,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
@@ -216,7 +216,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereRerankResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ConvertAudioRequest implements channel.Adaptor.
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
@@ -117,7 +122,7 @@ func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
// GetRequestURL implements channel.Adaptor.
|
||||
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
// Init implements channel.Adaptor.
|
||||
|
||||
@@ -49,7 +49,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
// convert coze response to openai response
|
||||
var response dto.TextResponse
|
||||
var cozeResponse CozeChatDetailResponse
|
||||
@@ -154,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
||||
var chatData CozeChatResponseData
|
||||
err := json.Unmarshal([]byte(data), &chatData)
|
||||
if err != nil {
|
||||
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
||||
var messageData CozeChatV3MessageDetail
|
||||
err := json.Unmarshal([]byte(data), &messageData)
|
||||
if err != nil {
|
||||
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var content string
|
||||
err = json.Unmarshal(messageData.Content, &content)
|
||||
if err != nil {
|
||||
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -203,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
|
||||
var errorData CozeError
|
||||
err := json.Unmarshal([]byte(data), &errorData)
|
||||
if err != nil {
|
||||
common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
||||
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
||||
common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
||||
}
|
||||
}
|
||||
|
||||
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
||||
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
|
||||
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
|
||||
|
||||
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
||||
// 将 conversationId和chatId作为参数发送get请求
|
||||
@@ -258,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo
|
||||
}
|
||||
|
||||
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
|
||||
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
|
||||
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
|
||||
|
||||
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
|
||||
req, err := http.NewRequest("GET", requestURL, nil)
|
||||
|
||||
@@ -19,10 +19,14 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -39,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fimBaseUrl := info.BaseUrl
|
||||
if !strings.HasSuffix(info.BaseUrl, "/beta") {
|
||||
fimBaseUrl := info.ChannelBaseUrl
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,11 @@ type Adaptor struct {
|
||||
BotType int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -56,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch a.BotType {
|
||||
case BotTypeWorkFlow:
|
||||
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil
|
||||
case BotTypeCompletion:
|
||||
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil
|
||||
case BotTypeAgent:
|
||||
fallthrough
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
|
||||
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
|
||||
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl)
|
||||
switch media.Type {
|
||||
case dto.ContentTypeImageURL:
|
||||
// Decode base64 data
|
||||
@@ -36,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
// Decode base64 string
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
common.SysError("failed to decode base64: " + err.Error())
|
||||
common.SysLog("failed to decode base64: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create temporary file
|
||||
tempFile, err := os.CreateTemp("", "dify-upload-*")
|
||||
if err != nil {
|
||||
common.SysError("failed to create temp file: " + err.Error())
|
||||
common.SysLog("failed to create temp file: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
defer tempFile.Close()
|
||||
@@ -51,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
|
||||
// Write decoded data to temp file
|
||||
if _, err := tempFile.Write(decodedData); err != nil {
|
||||
common.SysError("failed to write to temp file: " + err.Error())
|
||||
common.SysLog("failed to write to temp file: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
|
||||
// Add user field
|
||||
if err := writer.WriteField("user", user); err != nil {
|
||||
common.SysError("failed to add user field: " + err.Error())
|
||||
common.SysLog("failed to add user field: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -74,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
// Create form file
|
||||
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
|
||||
if err != nil {
|
||||
common.SysError("failed to create form file: " + err.Error())
|
||||
common.SysLog("failed to create form file: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy file content to form
|
||||
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
|
||||
common.SysError("failed to copy file content: " + err.Error())
|
||||
common.SysLog("failed to copy file content: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
writer.Close()
|
||||
@@ -88,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("POST", uploadUrl, body)
|
||||
if err != nil {
|
||||
common.SysError("failed to create request: " + err.Error())
|
||||
common.SysLog("failed to create request: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
client := service.GetHttpClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysError("failed to send request: " + err.Error())
|
||||
common.SysLog("failed to send request: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -109,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
Id string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
common.SysError("failed to decode response: " + err.Error())
|
||||
common.SysLog("failed to decode response: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -219,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
var openaiResponse dto.ChatCompletionsStreamResponse
|
||||
@@ -239,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
common.SysLog(err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -258,7 +258,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &difyResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/setting/model_setting"
|
||||
@@ -21,10 +20,33 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
if len(request.Contents) > 0 {
|
||||
for i, content := range request.Contents {
|
||||
if i == 0 {
|
||||
if request.Contents[0].Role == "" {
|
||||
request.Contents[0].Role = "user"
|
||||
}
|
||||
}
|
||||
for _, part := range content.Parts {
|
||||
if part.FileData != nil {
|
||||
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
|
||||
part.FileData.MimeType = "video/webm"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -37,26 +59,33 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not supported model for image generation")
|
||||
}
|
||||
|
||||
// convert size to aspect ratio
|
||||
// convert size to aspect ratio but allow user to specify aspect ratio
|
||||
aspectRatio := "1:1" // default aspect ratio
|
||||
switch request.Size {
|
||||
case "1024x1024":
|
||||
aspectRatio = "1:1"
|
||||
case "1024x1792":
|
||||
aspectRatio = "9:16"
|
||||
case "1792x1024":
|
||||
aspectRatio = "16:9"
|
||||
size := strings.TrimSpace(request.Size)
|
||||
if size != "" {
|
||||
if strings.Contains(size, ":") {
|
||||
aspectRatio = size
|
||||
} else {
|
||||
switch size {
|
||||
case "1024x1024":
|
||||
aspectRatio = "1:1"
|
||||
case "1024x1792":
|
||||
aspectRatio = "9:16"
|
||||
case "1792x1024":
|
||||
aspectRatio = "16:9"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// build gemini imagen request
|
||||
geminiRequest := GeminiImageRequest{
|
||||
Instances: []GeminiImageInstance{
|
||||
geminiRequest := dto.GeminiImageRequest{
|
||||
Instances: []dto.GeminiImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: GeminiImageParameters{
|
||||
SampleCount: request.N,
|
||||
Parameters: dto.GeminiImageParameters{
|
||||
SampleCount: int(request.N),
|
||||
AspectRatio: aspectRatio,
|
||||
PersonGeneration: "allow_adult", // default allow adult
|
||||
},
|
||||
@@ -86,20 +115,27 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
|
||||
action := "embedContent"
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
action = "batchEmbedContents"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
action := "generateContent"
|
||||
if info.IsStream {
|
||||
action = "streamGenerateContent?alt=sse"
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
info.DisablePing = true
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -113,7 +149,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
geminiRequest, err := CovertGemini2OpenAI(*request, info)
|
||||
geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -134,29 +170,38 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
if len(inputs) == 0 {
|
||||
return nil, errors.New("input is empty")
|
||||
}
|
||||
|
||||
// only process the first input
|
||||
geminiRequest := GeminiEmbeddingRequest{
|
||||
Content: GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: inputs[0],
|
||||
// We always build a batch-style payload with `requests`, so ensure we call the
|
||||
// batch endpoint upstream to avoid payload/endpoint mismatches.
|
||||
info.IsGeminiBatchEmbedding = true
|
||||
// process all inputs
|
||||
geminiRequests := make([]map[string]interface{}, 0, len(inputs))
|
||||
for _, input := range inputs {
|
||||
geminiRequest := map[string]interface{}{
|
||||
"model": fmt.Sprintf("models/%s", info.UpstreamModelName),
|
||||
"content": dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: input,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// set specific parameters for different models
|
||||
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
|
||||
switch info.UpstreamModelName {
|
||||
case "text-embedding-004":
|
||||
// except embedding-001 supports setting `OutputDimensionality`
|
||||
if request.Dimensions > 0 {
|
||||
geminiRequest.OutputDimensionality = request.Dimensions
|
||||
}
|
||||
|
||||
// set specific parameters for different models
|
||||
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
|
||||
switch info.UpstreamModelName {
|
||||
case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
|
||||
// Only newer models introduced after 2024 support OutputDimensionality
|
||||
if request.Dimensions > 0 {
|
||||
geminiRequest["outputDimensionality"] = request.Dimensions
|
||||
}
|
||||
}
|
||||
geminiRequests = append(geminiRequests, geminiRequest)
|
||||
}
|
||||
|
||||
return geminiRequest, nil
|
||||
return map[string]interface{}{
|
||||
"requests": geminiRequests,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
@@ -170,6 +215,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
|
||||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
|
||||
return NativeGeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
if info.IsStream {
|
||||
return GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
@@ -194,72 +243,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
|
||||
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
|
||||
// // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
|
||||
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
|
||||
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// thinkingModelName := info.OriginModelName + "-thinking"
|
||||
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
|
||||
// info.OriginModelName = thinkingModelName
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiImageResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
||||
}
|
||||
|
||||
for _, prediction := range geminiResponse.Predictions {
|
||||
if prediction.RaiFilteredReason != "" {
|
||||
continue // skip filtered image
|
||||
}
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
||||
// each image has fixed 258 tokens
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -1,222 +0,0 @@
|
||||
package gemini
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
}
|
||||
|
||||
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
|
||||
c.ThinkingBudget = &budget
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
|
||||
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiInlineData // Use type alias to avoid recursion
|
||||
var aux struct {
|
||||
Alias
|
||||
MimeTypeSnake string `json:"mime_type"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
|
||||
|
||||
// Prioritize snake_case if present
|
||||
if aux.MimeTypeSnake != "" {
|
||||
g.MimeType = aux.MimeTypeSnake
|
||||
} else if aux.MimeType != "" { // Fallback to camelCase from Alias
|
||||
g.MimeType = aux.MimeType
|
||||
}
|
||||
// g.Data would be populated by aux.Alias.Data
|
||||
return nil
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
FunctionName string `json:"name"`
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type GeminiPartExecutableCode struct {
|
||||
Language string `json:"language,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiPartCodeExecutionResult struct {
|
||||
Outcome string `json:"outcome,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiFileData struct {
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
FileUri string `json:"fileUri,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
|
||||
func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||
// Alias to avoid recursion during unmarshalling
|
||||
type Alias GeminiPart
|
||||
var aux struct {
|
||||
Alias
|
||||
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign fields from alias
|
||||
*p = GeminiPart(aux.Alias)
|
||||
|
||||
// Prioritize snake_case for InlineData if present
|
||||
if aux.InlineDataSnake != nil {
|
||||
p.InlineData = aux.InlineDataSnake
|
||||
} else if aux.InlineData != nil { // Fallback to camelCase from Alias
|
||||
p.InlineData = aux.InlineData
|
||||
}
|
||||
// Other fields like Text, FunctionCall etc. are already populated via aux.Alias
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTool struct {
|
||||
GoogleSearch any `json:"googleSearch,omitempty"`
|
||||
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
|
||||
CodeExecution any `json:"codeExecution,omitempty"`
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason *string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
}
|
||||
|
||||
type GeminiPromptTokensDetails struct {
|
||||
Modality string `json:"modality"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// Imagen related structs
|
||||
type GeminiImageRequest struct {
|
||||
Instances []GeminiImageInstance `json:"instances"`
|
||||
Parameters GeminiImageParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type GeminiImageInstance struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type GeminiImageParameters struct {
|
||||
SampleCount int `json:"sampleCount,omitempty"`
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
PersonGeneration string `json:"personGeneration,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiImageResponse struct {
|
||||
Predictions []GeminiImagePrediction `json:"predictions"`
|
||||
}
|
||||
|
||||
type GeminiImagePrediction struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
|
||||
SafetyAttributes any `json:"safetyAttributes,omitempty"`
|
||||
}
|
||||
|
||||
// Embedding related structs
|
||||
type GeminiEmbeddingRequest struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
TaskType string `json:"taskType,omitempty"`
|
||||
Title string `json:"title,omitempty"`
|
||||
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiEmbeddingResponse struct {
|
||||
Embedding ContentEmbedding `json:"embedding"`
|
||||
}
|
||||
|
||||
type ContentEmbedding struct {
|
||||
Values []float64 `json:"values"`
|
||||
}
|
||||
@@ -5,22 +5,25 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
@@ -28,10 +31,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
@@ -43,6 +46,32 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputCounts := 0
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||
imageOutputCounts++
|
||||
}
|
||||
}
|
||||
}
|
||||
if imageOutputCounts != 0 {
|
||||
usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
|
||||
usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
|
||||
c.Set("gemini_image_tokens", imageOutputCounts*1290)
|
||||
}
|
||||
}
|
||||
|
||||
// if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
// for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
// if detail.Modality == "IMAGE" {
|
||||
// usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
// usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
// c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
@@ -51,17 +80,47 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
}
|
||||
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
var geminiResponse dto.GeminiEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
@@ -71,10 +130,10 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -103,17 +162,31 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
if detail.Modality == "IMAGE" {
|
||||
usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.StringData(c, data)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -48,12 +50,20 @@ const (
|
||||
flash25LiteMaxBudget = 24576
|
||||
)
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
func isNew25ProModel(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
}
|
||||
|
||||
func is25FlashLiteModel(modelName string) bool {
|
||||
return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
}
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := isNew25ProModel(modelName)
|
||||
is25FlashLite := is25FlashLiteModel(modelName)
|
||||
|
||||
if is25FlashLite {
|
||||
if budget < flash25LiteMinBudget {
|
||||
@@ -80,7 +90,34 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
return budget
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
|
||||
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
|
||||
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
|
||||
func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
isNew25Pro := isNew25ProModel(modelName)
|
||||
is25FlashLite := is25FlashLiteModel(modelName)
|
||||
|
||||
maxBudget := 0
|
||||
if is25FlashLite {
|
||||
maxBudget = flash25LiteMaxBudget
|
||||
}
|
||||
if isNew25Pro {
|
||||
maxBudget = pro25MaxBudget
|
||||
} else {
|
||||
maxBudget = flash25MaxBudget
|
||||
}
|
||||
switch effort {
|
||||
case "high":
|
||||
maxBudget = maxBudget * 80 / 100
|
||||
case "medium":
|
||||
maxBudget = maxBudget * 50 / 100
|
||||
case "low":
|
||||
maxBudget = maxBudget * 20 / 100
|
||||
}
|
||||
return clampThinkingBudget(modelName, maxBudget)
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
@@ -92,7 +129,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
@@ -112,22 +149,27 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
|
||||
} else {
|
||||
if len(oaiRequest) > 0 {
|
||||
// 如果有reasoningEffort参数,则根据其值设置思考预算
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
@@ -136,14 +178,14 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
GenerationConfig: dto.GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
MaxOutputTokens: textRequest.GetMaxTokens(),
|
||||
Seed: int64(textRequest.Seed),
|
||||
},
|
||||
}
|
||||
@@ -155,11 +197,41 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
adaptorWithExtraBody := false
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
if len(textRequest.ExtraBody) > 0 {
|
||||
if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
var extraBody map[string]interface{}
|
||||
if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
|
||||
return nil, fmt.Errorf("invalid extra body: %w", err)
|
||||
}
|
||||
// eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
|
||||
if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
|
||||
adaptorWithExtraBody = true
|
||||
if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
|
||||
if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
|
||||
budgetInt := int(budget)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(budgetInt),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !adaptorWithExtraBody {
|
||||
ThinkingAdaptor(&geminiRequest, info, textRequest)
|
||||
}
|
||||
|
||||
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
safetySettings = append(safetySettings, GeminiChatSafetySettings{
|
||||
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
|
||||
Category: category,
|
||||
Threshold: model_setting.GetGeminiSafetySetting(category),
|
||||
})
|
||||
@@ -196,32 +268,35 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
tool.Function.Parameters = cleanedParams
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
geminiTools := geminiRequest.GetTools()
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
geminiRequest.SetTools(geminiTools)
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
|
||||
|
||||
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
if len(textRequest.ResponseFormat.JsonSchema) > 0 {
|
||||
// 先将json.RawMessage解析
|
||||
var jsonSchema dto.FormatJsonSchema
|
||||
if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_call_ids := make(map[string]string)
|
||||
@@ -233,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
@@ -260,18 +335,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
functionResp := &FunctionResponse{
|
||||
functionResp := &dto.GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: contentMap,
|
||||
}
|
||||
|
||||
*parts = append(*parts, GeminiPart{
|
||||
*parts = append(*parts, dto.GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
var parts []dto.GeminiPart
|
||||
content := dto.GeminiChatContent{
|
||||
Role: message.Role,
|
||||
}
|
||||
// isToolCall := false
|
||||
@@ -285,8 +360,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
toolCall := dto.GeminiPart{
|
||||
FunctionCall: &dto.FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
@@ -303,7 +378,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
@@ -315,7 +390,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
|
||||
// 是url,获取文件的类型和base64编码的数据
|
||||
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
|
||||
}
|
||||
@@ -326,8 +401,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
@@ -337,8 +412,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -352,8 +427,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -366,8 +441,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
parts = append(parts, dto.GeminiPart{
|
||||
InlineData: &dto.GeminiInlineData{
|
||||
MimeType: "audio/" + part.GetInputAudio().Format,
|
||||
Data: base64String,
|
||||
},
|
||||
@@ -387,8 +462,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{
|
||||
Text: strings.Join(system_content, "\n"),
|
||||
},
|
||||
@@ -631,7 +706,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
return data
|
||||
}
|
||||
|
||||
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
@@ -653,7 +728,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
@@ -674,7 +749,16 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
||||
var texts []string
|
||||
var toolCalls []dto.ToolCallResponse
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
if part.InlineData != nil {
|
||||
// 媒体内容
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
} else {
|
||||
// 其他媒体类型,直接显示链接
|
||||
texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data))
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
if call := getResponseToolCall(&part); call != nil {
|
||||
toolCalls = append(toolCalls, *call)
|
||||
@@ -720,10 +804,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
isStop := false
|
||||
hasImage := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
isStop = true
|
||||
@@ -732,7 +815,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
//Role: "assistant",
|
||||
},
|
||||
}
|
||||
var texts []string
|
||||
@@ -754,7 +837,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
if strings.HasPrefix(part.InlineData.MimeType, "image") {
|
||||
imgText := ""
|
||||
texts = append(texts, imgText)
|
||||
hasImage = true
|
||||
}
|
||||
} else if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
@@ -762,6 +844,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
call.SetIndex(len(choice.Delta.ToolCalls))
|
||||
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
|
||||
}
|
||||
|
||||
} else if part.Thought {
|
||||
isThought = true
|
||||
texts = append(texts, part.Text)
|
||||
@@ -791,28 +874,60 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Choices = choices
|
||||
return &response, isStop, hasImage
|
||||
return &response, isStop
|
||||
}
|
||||
|
||||
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
streamData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle stream format: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
|
||||
streamData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal stream response: %w", err)
|
||||
}
|
||||
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
finishReason := constant.FinishReasonStop
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if hasImage {
|
||||
imageCount++
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
@@ -829,18 +944,47 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
err = helper.ObjectData(c, response)
|
||||
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
|
||||
if response.IsToolCall() {
|
||||
emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
|
||||
emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
|
||||
emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
response.ClearToolCalls()
|
||||
if response.IsFinished() {
|
||||
response.Choices[0].FinishReason = nil
|
||||
}
|
||||
} else {
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = handleStream(c, info, response)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
if isStop {
|
||||
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
helper.ObjectData(c, response)
|
||||
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
@@ -851,14 +995,24 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("send final response failed: " + err.Error())
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
}
|
||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
// helper.Done(c)
|
||||
//}
|
||||
//resp.Body.Close()
|
||||
return usage, nil
|
||||
}
|
||||
@@ -866,19 +1020,19 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
@@ -900,40 +1054,55 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
responseBody, err = common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
case types.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
|
||||
claudeRespStr, err := common.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case types.RelayFormatGemini:
|
||||
break
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
c.Writer.Write(jsonResponse)
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []dto.OpenAIEmbeddingResponseItem{
|
||||
{
|
||||
Object: "embedding",
|
||||
Embedding: geminiResponse.Embedding.Values,
|
||||
Index: 0,
|
||||
},
|
||||
},
|
||||
Model: info.UpstreamModelName,
|
||||
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
|
||||
Model: info.UpstreamModelName,
|
||||
}
|
||||
|
||||
for i, embedding := range geminiResponse.Embeddings {
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
|
||||
Object: "embedding",
|
||||
Embedding: embedding.Values,
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
// calculate usage
|
||||
@@ -949,10 +1118,64 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
openAIResponse.Usage = *usage
|
||||
|
||||
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse dto.GeminiImageResponse
|
||||
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if len(geminiResponse.Predictions) == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// convert to openai format response
|
||||
openAIResponse := dto.ImageResponse{
|
||||
Created: common.GetTimestamp(),
|
||||
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
|
||||
}
|
||||
|
||||
for _, prediction := range geminiResponse.Predictions {
|
||||
if prediction.RaiFilteredReason != "" {
|
||||
continue // skip filtered image
|
||||
}
|
||||
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
|
||||
// each image has fixed 258 tokens
|
||||
const imageTokens = 258
|
||||
generatedImages := len(openAIResponse.Data)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||
TotalTokens: imageTokens * generatedImages,
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
@@ -13,11 +12,18 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
@@ -26,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
||||
var jimengResponse ImageResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
err = json.Unmarshal(responseBody, &jimengResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Check if the response indicates an error
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
|
||||
logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
|
||||
payloadHash := sha256.Sum256(body)
|
||||
hexPayloadHash := hex.EncodeToString(payloadHash[:])
|
||||
c.Set(HexPayloadHashKey, hexPayloadHash)
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -40,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
@@ -6,5 +6,5 @@ import (
|
||||
)
|
||||
|
||||
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -36,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
}
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -49,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
|
||||
suffix = "embeddings"
|
||||
}
|
||||
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
|
||||
fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix)
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,4 +6,4 @@ var ModelList = []string{
|
||||
"m3e-small",
|
||||
}
|
||||
|
||||
var ChannelName = "mokaai"
|
||||
var ChannelName = "mokaai"
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
@@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &fullTextResponse.Usage, nil
|
||||
}
|
||||
|
||||
110
relay/channel/moonshot/adaptor.go
Normal file
110
relay/channel/moonshot/adaptor.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package moonshot
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
@@ -17,10 +17,21 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -37,11 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
return info.ChannelBaseUrl + "/v1/chat/completions", nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embed", nil
|
||||
return info.ChannelBaseUrl + "/api/embed", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return requestOpenAI2Ollama(*request)
|
||||
return requestOpenAI2Ollama(c, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -76,11 +90,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package ollama
|
||||
|
||||
import "one-api/dto"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
@@ -19,6 +22,7 @@ type OllamaRequest struct {
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
@@ -24,7 +24,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if not base64
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
|
||||
} else {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
return &OllamaRequest{
|
||||
ollamaRequest := &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
@@ -60,14 +60,16 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
Prompt: request.Prompt,
|
||||
StreamOptions: request.StreamOptions,
|
||||
Suffix: request.Suffix,
|
||||
}, nil
|
||||
}
|
||||
ollamaRequest.Think = request.Think
|
||||
return ollamaRequest, nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
@@ -88,15 +90,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
@@ -117,9 +119,9 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
}
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
service.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/openrouter"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -34,15 +34,55 @@ type Adaptor struct {
|
||||
ResponseFormat string
|
||||
}
|
||||
|
||||
// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别
|
||||
// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
|
||||
// minimal effort only available in gpt-5
|
||||
func parseReasoningEffortFromModelSuffix(model string) (string, string) {
|
||||
effortSuffixes := []string{"-high", "-minimal", "-low", "-medium"}
|
||||
for _, suffix := range effortSuffixes {
|
||||
if strings.HasSuffix(model, suffix) {
|
||||
effort := strings.TrimPrefix(suffix, "-")
|
||||
originModel := strings.TrimSuffix(model, suffix)
|
||||
return effort, originModel
|
||||
}
|
||||
}
|
||||
return "", model
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
// 使用 service.GeminiToOpenAIRequest 转换请求格式
|
||||
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.ConvertOpenAIRequest(c, info, openaiRequest)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
//if !strings.Contains(request.Model, "claude") {
|
||||
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
|
||||
//}
|
||||
//if common.DebugEnabled {
|
||||
// bodyBytes := []byte(common.GetJsonString(request))
|
||||
// err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
|
||||
// if err != nil {
|
||||
// println(fmt.Sprintf("failed to save request body to file: %v", err))
|
||||
// }
|
||||
//}
|
||||
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if info.SupportStreamOptions {
|
||||
//if common.DebugEnabled {
|
||||
// println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest)))
|
||||
// // Save request body to file for debugging
|
||||
// bodyBytes := []byte(common.GetJsonString(aiRequest))
|
||||
// err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
|
||||
// if err != nil {
|
||||
// println(fmt.Sprintf("failed to save request body to file: %v", err))
|
||||
// }
|
||||
//}
|
||||
if info.SupportStreamOptions && info.IsStream {
|
||||
aiRequest.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
@@ -64,18 +104,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
if strings.HasPrefix(info.ChannelBaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://")
|
||||
baseUrl = "wss://" + baseUrl
|
||||
info.BaseUrl = baseUrl
|
||||
} else if strings.HasPrefix(info.BaseUrl, "http://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
|
||||
info.ChannelBaseUrl = baseUrl
|
||||
} else if strings.HasPrefix(info.ChannelBaseUrl, "http://") {
|
||||
baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://")
|
||||
baseUrl = "ws://" + baseUrl
|
||||
info.BaseUrl = baseUrl
|
||||
info.ChannelBaseUrl = baseUrl
|
||||
}
|
||||
}
|
||||
switch info.ChannelType {
|
||||
@@ -89,10 +126,27 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
task = strings.TrimPrefix(task, "messages")
|
||||
task = "chat/completions" + task
|
||||
}
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == relayconstant.RelayModeResponses {
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
responsesApiVersion := "preview"
|
||||
|
||||
subUrl := "/openai/v1/responses"
|
||||
if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") {
|
||||
subUrl = "/openai/responses"
|
||||
responsesApiVersion = apiVersion
|
||||
}
|
||||
|
||||
if info.ChannelOtherSettings.AzureResponsesVersion != "" {
|
||||
responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion
|
||||
}
|
||||
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
model_ := info.UpstreamModelName
|
||||
@@ -105,15 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.GetRequestURL(info)
|
||||
case constant.ChannelTypeCustom:
|
||||
url := info.BaseUrl
|
||||
url := info.ChannelBaseUrl
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
return url, nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,28 +220,105 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if len(request.Usage) == 0 {
|
||||
request.Usage = json.RawMessage(`{"include":true}`)
|
||||
}
|
||||
// 适配 OpenRouter 的 thinking 后缀
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-thinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
request.Model = info.UpstreamModelName
|
||||
if len(request.Reasoning) == 0 {
|
||||
reasoning := map[string]any{
|
||||
"enabled": true,
|
||||
}
|
||||
if request.ReasoningEffort != "" && request.ReasoningEffort != "none" {
|
||||
reasoning["effort"] = request.ReasoningEffort
|
||||
}
|
||||
marshal, err := common.Marshal(reasoning)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
|
||||
}
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
// 清空多余的ReasoningEffort
|
||||
request.ReasoningEffort = ""
|
||||
} else {
|
||||
if len(request.Reasoning) == 0 {
|
||||
// 适配 OpenAI 的 ReasoningEffort 格式
|
||||
if request.ReasoningEffort != "" {
|
||||
reasoning := map[string]any{
|
||||
"enabled": true,
|
||||
}
|
||||
if request.ReasoningEffort != "none" {
|
||||
reasoning["effort"] = request.ReasoningEffort
|
||||
marshal, err := common.Marshal(reasoning)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
|
||||
}
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
}
|
||||
}
|
||||
request.ReasoningEffort = ""
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/en/api/openai-sdk#extended-thinking-support
|
||||
// 没有做排除3.5Haiku等,要出问题再加吧,最佳兼容性(不是
|
||||
if request.THINKING != nil && strings.HasPrefix(info.UpstreamModelName, "anthropic") {
|
||||
var thinking dto.Thinking // Claude标准Thinking格式
|
||||
if err := json.Unmarshal(request.THINKING, &thinking); err != nil {
|
||||
return nil, fmt.Errorf("error Unmarshal thinking: %w", err)
|
||||
}
|
||||
|
||||
// 只有当 thinking.Type 是 "enabled" 时才处理
|
||||
if thinking.Type == "enabled" {
|
||||
// 检查 BudgetTokens 是否为 nil
|
||||
if thinking.BudgetTokens == nil {
|
||||
return nil, fmt.Errorf("BudgetTokens is nil when thinking is enabled")
|
||||
}
|
||||
|
||||
reasoning := openrouter.RequestReasoning{
|
||||
MaxTokens: *thinking.BudgetTokens,
|
||||
}
|
||||
|
||||
marshal, err := common.Marshal(reasoning)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling reasoning: %w", err)
|
||||
}
|
||||
|
||||
request.Reasoning = marshal
|
||||
}
|
||||
|
||||
// 清空 THINKING
|
||||
request.THINKING = nil
|
||||
}
|
||||
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
request.Temperature = nil
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
} else if strings.HasSuffix(request.Model, "-low") {
|
||||
request.ReasoningEffort = "low"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-low")
|
||||
} else if strings.HasSuffix(request.Model, "-medium") {
|
||||
request.ReasoningEffort = "medium"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-medium")
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") {
|
||||
request.Temperature = nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
|
||||
if info.UpstreamModelName != "gpt-5-chat-latest" {
|
||||
request.Temperature = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName)
|
||||
if effort != "" {
|
||||
request.ReasoningEffort = effort
|
||||
info.UpstreamModelName = originModel
|
||||
request.Model = originModel
|
||||
}
|
||||
|
||||
info.ReasoningEffort = request.ReasoningEffort
|
||||
info.UpstreamModelName = request.Model
|
||||
|
||||
// o系列模型developer适配(o1-mini除外)
|
||||
if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
|
||||
if !strings.HasPrefix(info.UpstreamModelName, "o1-mini") && !strings.HasPrefix(info.UpstreamModelName, "o1-preview") {
|
||||
//修改第一个Message的内容,将system改为developer
|
||||
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
|
||||
request.Messages[0].Role = "developer"
|
||||
@@ -260,40 +394,42 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
// 获取所有表单字段
|
||||
formData := c.Request.PostForm
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
if key == "model" {
|
||||
continue
|
||||
// 使用已解析的 multipart 表单,避免重复解析
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
mf = c.Request.MultipartForm
|
||||
}
|
||||
|
||||
// 写入所有非文件字段
|
||||
if mf != nil {
|
||||
for key, values := range mf.Value {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the multipart form to handle both single image and multiple images
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
}
|
||||
|
||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||
if mf != nil && mf.File != nil {
|
||||
// Check if "image" field exists in any form, including array notation
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range c.Request.MultipartForm.File {
|
||||
for fieldName, files := range mf.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
for _, file := range files {
|
||||
imageFiles = append(imageFiles, file)
|
||||
}
|
||||
imageFiles = append(imageFiles, files...)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,7 +446,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// If multiple images, use image[] as the field name
|
||||
fieldName := "image"
|
||||
@@ -334,15 +469,18 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
||||
}
|
||||
|
||||
// 复制完立即关闭,避免在循环内使用 defer 占用资源
|
||||
_ = file.Close()
|
||||
}
|
||||
|
||||
// Handle mask file if present
|
||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||
if maskFiles, exists := mf.File["mask"]; exists && len(maskFiles) > 0 {
|
||||
maskFile, err := maskFiles[0].Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open mask file")
|
||||
}
|
||||
defer maskFile.Close()
|
||||
// 复制完立即关闭,避免在循环内使用 defer 占用资源
|
||||
|
||||
// Determine MIME type for mask file
|
||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||
@@ -360,6 +498,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
if _, err := io.Copy(maskPart, maskFile); err != nil {
|
||||
return nil, errors.New("copy mask file failed")
|
||||
}
|
||||
_ = maskFile.Close()
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("no multipart form data found")
|
||||
@@ -368,7 +507,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return bytes.NewReader(requestBody.Bytes()), nil
|
||||
return &requestBody, nil
|
||||
|
||||
default:
|
||||
return request, nil
|
||||
@@ -396,16 +535,17 @@ func detectImageMimeType(filename string) string {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// 模型后缀转换 reasoning effort
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.Reasoning.Effort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
} else if strings.HasSuffix(request.Model, "-low") {
|
||||
request.Reasoning.Effort = "low"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-low")
|
||||
} else if strings.HasSuffix(request.Model, "-medium") {
|
||||
request.Reasoning.Effort = "medium"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-medium")
|
||||
// 转换模型推理力度后缀
|
||||
effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
|
||||
if effort != "" {
|
||||
if request.Reasoning == nil {
|
||||
request.Reasoning = &dto.Reasoning{
|
||||
Effort: effort,
|
||||
}
|
||||
} else {
|
||||
request.Reasoning.Effort = effort
|
||||
}
|
||||
request.Model = originModel
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
@@ -456,8 +596,6 @@ func (a *Adaptor) GetModelList() []string {
|
||||
switch a.ChannelType {
|
||||
case constant.ChannelType360:
|
||||
return ai360.ModelList
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ModelList
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ModelList
|
||||
case constant.ChannelTypeMiniMax:
|
||||
@@ -475,8 +613,6 @@ func (a *Adaptor) GetChannelName() string {
|
||||
switch a.ChannelType {
|
||||
case constant.ChannelType360:
|
||||
return ai360.ChannelName
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ChannelName
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ChannelName
|
||||
case constant.ChannelTypeMiniMax:
|
||||
|
||||
@@ -12,13 +12,25 @@ var ModelList = []string{
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
|
||||
"gpt-4.1", "gpt-4.1-2025-04-14",
|
||||
"gpt-4.1-mini", "gpt-4.1-mini-2025-04-14",
|
||||
"gpt-4.1-nano", "gpt-4.1-nano-2025-04-14",
|
||||
"o1", "o1-2024-12-17",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"o1-pro", "o1-pro-2025-03-19",
|
||||
"o3-mini", "o3-mini-2025-01-31",
|
||||
"o3-mini-high", "o3-mini-2025-01-31-high",
|
||||
"o3-mini-low", "o3-mini-2025-01-31-low",
|
||||
"o3-mini-medium", "o3-mini-2025-01-31-medium",
|
||||
"o1", "o1-2024-12-17",
|
||||
"o3", "o3-2025-04-16",
|
||||
"o3-pro", "o3-pro-2025-06-10",
|
||||
"o3-deep-research", "o3-deep-research-2025-06-26",
|
||||
"o4-mini", "o4-mini-2025-04-16",
|
||||
"o4-mini-deep-research", "o4-mini-deep-research-2025-06-26",
|
||||
"gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest",
|
||||
"gpt-5-mini", "gpt-5-mini-2025-08-07",
|
||||
"gpt-5-nano", "gpt-5-nano-2025-08-07",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
|
||||
"gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
|
||||
@@ -27,7 +39,7 @@ var ModelList = []string{
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-3",
|
||||
"dall-e-3", "gpt-image-1",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
|
||||
@@ -4,30 +4,37 @@ import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
case types.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
case relaycommon.RelayFormatClaude:
|
||||
case types.RelayFormatClaude:
|
||||
return handleClaudeFormat(c, data, info)
|
||||
case types.RelayFormatGemini:
|
||||
return handleGeminiFormat(c, data, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,6 +48,32 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// 如果返回 nil,表示没有实际内容,跳过发送
|
||||
if geminiResponse == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "failed to marshal gemini response: "+err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// send gemini format response
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
_ = helper.FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
@@ -74,14 +107,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
|
||||
var streamResponses []dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
|
||||
common.SysError("error processing stream response: " + err.Error())
|
||||
common.SysLog("error processing stream response: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -110,7 +143,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
|
||||
var streamResponses []dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
|
||||
// 一次性解析失败,逐个解析
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
for _, item := range streamItems {
|
||||
var streamResponse dto.CompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
|
||||
@@ -151,19 +184,21 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
|
||||
*containStreamUsage = true
|
||||
*usage = lastStreamResponse.Usage
|
||||
if !info.ShouldIncludeUsage {
|
||||
*shouldSendLastResp = false
|
||||
*shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
|
||||
return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
|
||||
responseId string, createAt int64, model string, systemFingerprint string,
|
||||
usage *dto.Usage, containStreamUsage bool) {
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
case types.RelayFormatOpenAI:
|
||||
if info.ShouldIncludeUsage && !containStreamUsage {
|
||||
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
|
||||
response.SetSystemFingerprint(systemFingerprint)
|
||||
@@ -171,11 +206,11 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
case relaycommon.RelayFormatClaude:
|
||||
case types.RelayFormatClaude:
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -183,8 +218,37 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
|
||||
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
|
||||
for _, resp := range claudeResponses {
|
||||
helper.ClaudeData(c, *resp)
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
|
||||
case types.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
|
||||
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
|
||||
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
|
||||
// 暂不知是否有程序会不兼容。
|
||||
|
||||
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
|
||||
|
||||
// openai 流响应开头的空数据
|
||||
if geminiResponse == nil {
|
||||
return
|
||||
}
|
||||
|
||||
geminiResponseStr, err := common.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
common.SysLog("error marshalling gemini response: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 发送最终的 Gemini 响应
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
|
||||
_ = helper.FlushWriter(c)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -108,11 +110,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
|
||||
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
common.LogError(c, "invalid response or response body")
|
||||
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||
logger.LogError(c, "invalid response or response body")
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
model := info.UpstreamModelName
|
||||
var responseId string
|
||||
@@ -123,30 +125,19 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
var streamItems []string // store stream items
|
||||
var forceFormat bool
|
||||
var thinkToContent bool
|
||||
|
||||
if info.ChannelSetting.ForceFormat {
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if info.ChannelSetting.ThinkingToContent {
|
||||
thinkToContent = true
|
||||
}
|
||||
|
||||
var (
|
||||
lastStreamData string
|
||||
)
|
||||
var lastStreamData string
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if lastStreamData != "" {
|
||||
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
if err != nil {
|
||||
common.SysError("error handling stream format: " + err.Error())
|
||||
common.SysLog("error handling stream format: " + err.Error())
|
||||
}
|
||||
}
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
if len(data) > 0 {
|
||||
lastStreamData = data
|
||||
streamItems = append(streamItems, data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -154,16 +145,18 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
shouldSendLastResp := true
|
||||
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
|
||||
&containStreamUsage, info, &shouldSendLastResp); err != nil {
|
||||
common.SysError("error handling last response: " + err.Error())
|
||||
logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
|
||||
}
|
||||
|
||||
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
if shouldSendLastResp {
|
||||
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
|
||||
common.SysError("error processing tokens: " + err.Error())
|
||||
logger.LogError(c, "error processing tokens: "+err.Error())
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
@@ -176,26 +169,28 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("upstream response body:", string(responseBody))
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
forceFormat := false
|
||||
@@ -203,21 +198,34 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
forceFormat = true
|
||||
}
|
||||
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
usageModified := false
|
||||
if simpleResponse.Usage.PromptTokens == 0 {
|
||||
completionTokens := simpleResponse.Usage.CompletionTokens
|
||||
if completionTokens == 0 {
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
PromptTokens: info.PromptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: info.PromptTokens + completionTokens,
|
||||
}
|
||||
usageModified = true
|
||||
}
|
||||
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
case types.RelayFormatOpenAI:
|
||||
if usageModified {
|
||||
var bodyMap map[string]interface{}
|
||||
err = common.Unmarshal(responseBody, &bodyMap)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
bodyMap["usage"] = simpleResponse.Usage
|
||||
responseBody, _ = common.Marshal(bodyMap)
|
||||
}
|
||||
if forceFormat {
|
||||
responseBody, err = common.Marshal(simpleResponse)
|
||||
if err != nil {
|
||||
@@ -226,16 +234,23 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
} else {
|
||||
break
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
case types.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||
claudeRespStr, err := common.Marshal(claudeResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
case types.RelayFormatGemini:
|
||||
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
|
||||
geminiRespStr, err := common.Marshal(geminiResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
responseBody = geminiRespStr
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return &simpleResponse.Usage, nil
|
||||
}
|
||||
@@ -247,7 +262,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
@@ -258,26 +273,41 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
c.Writer.WriteHeaderNow()
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
var responseData struct {
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
||||
if responseData.Usage.TotalTokens > 0 {
|
||||
usage := responseData.Usage
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
}
|
||||
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = audioTokens
|
||||
usage.CompletionTokens = 0
|
||||
@@ -386,7 +416,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
@@ -459,7 +489,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
@@ -474,9 +504,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
@@ -491,7 +521,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
@@ -517,7 +547,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
common.LogError(c, "realtime error: "+err.Error())
|
||||
logger.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
@@ -553,21 +583,21 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// Once we've written to the client, we should not return errors anymore
|
||||
// because the upstream has already consumed resources and returned content
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -16,43 +17,58 @@ import (
|
||||
)
|
||||
|
||||
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// read response body
|
||||
var responsesResponse dto.OpenAIResponsesResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
err = common.Unmarshal(responseBody, &responsesResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if responsesResponse.Error != nil {
|
||||
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
|
||||
if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// compute usage
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
||||
if responsesResponse.Usage != nil {
|
||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = responsesResponse.Usage.TotalTokens
|
||||
if responsesResponse.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
|
||||
return &usage, nil
|
||||
}
|
||||
// 解析 Tools 用量
|
||||
for _, tool := range responsesResponse.Tools {
|
||||
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
|
||||
buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
|
||||
if !ok || buildToolinfo == nil {
|
||||
logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
|
||||
continue
|
||||
}
|
||||
buildToolinfo.CallCount++
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
common.LogError(c, "invalid response or response body")
|
||||
logger.LogError(c, "invalid response or response body")
|
||||
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||
}
|
||||
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var usage = &dto.Usage{}
|
||||
var responseTextBuilder strings.Builder
|
||||
|
||||
@@ -64,9 +80,20 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
// 处理输出文本
|
||||
responseTextBuilder.WriteString(streamResponse.Delta)
|
||||
@@ -79,6 +106,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -93,5 +122,11 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
}
|
||||
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -37,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -18,30 +18,6 @@ import (
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||
|
||||
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||
},
|
||||
Temperature: textRequest.Temperature,
|
||||
CandidateCount: textRequest.N,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.MaxTokens,
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
} else {
|
||||
palmMessage.Author = "1"
|
||||
}
|
||||
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
|
||||
}
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
@@ -82,15 +58,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
go func() {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
common.SysLog("error reading stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
@@ -102,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
}
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
@@ -120,20 +96,20 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
|
||||
return false
|
||||
}
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
@@ -157,6 +133,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -37,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,20 +18,24 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -39,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return "", errors.New("invalid relay mode")
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -81,16 +85,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
usage, err = siliconflowRerankHandler(c, info, resp)
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
var siliconflowResp SFRerankResponse
|
||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
||||
@@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -74,9 +74,9 @@ type TaskAdaptor struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
|
||||
// apiKey format: "access_key|secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
@@ -87,7 +87,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
@@ -108,19 +108,19 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
@@ -139,12 +139,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
@@ -37,19 +38,52 @@ type SubmitReq struct {
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TrajectoryPoint struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
}
|
||||
|
||||
type DynamicMask struct {
|
||||
Mask string `json:"mask,omitempty"`
|
||||
Trajectories []TrajectoryPoint `json:"trajectories,omitempty"`
|
||||
}
|
||||
|
||||
type CameraConfig struct {
|
||||
Horizontal float64 `json:"horizontal,omitempty"`
|
||||
Vertical float64 `json:"vertical,omitempty"`
|
||||
Pan float64 `json:"pan,omitempty"`
|
||||
Tilt float64 `json:"tilt,omitempty"`
|
||||
Roll float64 `json:"roll,omitempty"`
|
||||
Zoom float64 `json:"zoom,omitempty"`
|
||||
}
|
||||
|
||||
type CameraControl struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Config *CameraConfig `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
ImageTail string `json:"image_tail,omitempty"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
StaticMask string `json:"static_mask,omitempty"`
|
||||
DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"`
|
||||
CameraControl *CameraControl `json:"camera_control,omitempty"`
|
||||
CallbackUrl string `json:"callback_url,omitempty"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
Data struct {
|
||||
TaskId string `json:"task_id"`
|
||||
@@ -73,25 +107,20 @@ type responsePayload struct {
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
|
||||
// apiKey format: "access_key|secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
@@ -112,13 +141,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
token, err := a.createJWTToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||
@@ -132,7 +161,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Kling specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
@@ -143,6 +172,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if body.Image == "" && body.ImageTail == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -151,7 +183,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
@@ -159,34 +191,26 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo,
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt Kling response parse first.
|
||||
var kResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
|
||||
return kResp.Data.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
// Fallback generic task response.
|
||||
var generic dto.TaskResponse[string]
|
||||
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
err = json.Unmarshal(responseBody, &kResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !generic.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||
if kResp.Code != 0 {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||
return generic.Data, responseBody, nil
|
||||
kResp.TaskId = kResp.Data.TaskId
|
||||
c.JSON(http.StatusOK, kResp)
|
||||
return kResp.Data.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
@@ -233,13 +257,19 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
CfgScale: 0.5,
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
|
||||
CfgScale: 0.5,
|
||||
StaticMask: "",
|
||||
DynamicMasks: []DynamicMask{},
|
||||
CameraControl: nil,
|
||||
CallbackUrl: "",
|
||||
ExternalTaskId: "",
|
||||
}
|
||||
if r.ModelName == "" {
|
||||
r.ModelName = "kling-v1"
|
||||
@@ -288,21 +318,25 @@ func defaultInt(v int, def int) int {
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
||||
return a.createJWTTokenWithKey(a.apiKey)
|
||||
}
|
||||
|
||||
//func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
// parts := strings.Split(apiKey, "|")
|
||||
// if len(parts) != 2 {
|
||||
// return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||
// }
|
||||
// return a.createJWTTokenWithKey(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||
//}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||
}
|
||||
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||
if accessKey == "" || secretKey == "" {
|
||||
return "", fmt.Errorf("access key and secret key are required")
|
||||
keyParts := strings.Split(apiKey, "|")
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
if len(keyParts) == 1 {
|
||||
return accessKey, nil
|
||||
}
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
now := time.Now().Unix()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": accessKey,
|
||||
@@ -315,12 +349,12 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
resPayload := responsePayload{}
|
||||
err := json.Unmarshal(respBody, &resPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
taskInfo.Code = resPayload.Code
|
||||
taskInfo.TaskID = resPayload.Data.TaskId
|
||||
taskInfo.Reason = resPayload.Message
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -16,6 +15,8 @@ import (
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TaskAdaptor struct {
|
||||
@@ -26,11 +27,11 @@ func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
|
||||
return nil, fmt.Errorf("not implement") // todo implement this method if needed
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
action := strings.ToUpper(c.Param("action"))
|
||||
|
||||
var sunoRequest *dto.SunoSubmitReq
|
||||
@@ -58,20 +59,20 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
baseURL := info.BaseUrl
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseURL := info.ChannelBaseUrl
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
sunoRequest, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
@@ -86,11 +87,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
@@ -139,7 +140,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task error: %v", err))
|
||||
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
285
relay/channel/task/vidu/adaptor.go
Normal file
285
relay/channel/task/vidu/adaptor.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package vidu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
MovementAmplitude string `json:"movement_amplitude,omitempty"`
|
||||
Bgm bool `json:"bgm,omitempty"`
|
||||
Payload string `json:"payload,omitempty"`
|
||||
CallbackUrl string `json:"callback_url,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
TaskId string `json:"task_id"`
|
||||
State string `json:"state"`
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
Prompt string `json:"prompt"`
|
||||
Duration int `json:"duration"`
|
||||
Seed int `json:"seed"`
|
||||
Resolution string `json:"resolution"`
|
||||
Bgm bool `json:"bgm"`
|
||||
MovementAmplitude string `json:"movement_amplitude"`
|
||||
Payload string `json:"payload"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type taskResultResponse struct {
|
||||
State string `json:"state"`
|
||||
ErrCode string `json:"err_code"`
|
||||
Credits int `json:"credits"`
|
||||
Payload string `json:"payload"`
|
||||
Creations []creation `json:"creations"`
|
||||
}
|
||||
|
||||
type creation struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
CoverURL string `json:"cover_url"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(body.Images) == 0 {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var path string
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
path = "/img2video"
|
||||
default:
|
||||
path = "/text2video"
|
||||
}
|
||||
return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var vResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &vResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if vResp.State == "failed" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, vResp)
|
||||
return vResp.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Token "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"viduq1", "vidu2.0", "vidu1.5"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "vidu"
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
}
|
||||
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Images: images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
MovementAmplitude: "auto",
|
||||
Bgm: false,
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func defaultString(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultInt(value, defaultValue int) int {
|
||||
if value == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
|
||||
var taskResp taskResultResponse
|
||||
err := json.Unmarshal(respBody, &taskResp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
|
||||
state := taskResp.State
|
||||
switch state {
|
||||
case "created", "queueing":
|
||||
taskInfo.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
taskInfo.Status = model.TaskStatusInProgress
|
||||
case "success":
|
||||
taskInfo.Status = model.TaskStatusSuccess
|
||||
if len(taskResp.Creations) > 0 {
|
||||
taskInfo.Url = taskResp.Creations[0].URL
|
||||
}
|
||||
case "failed":
|
||||
taskInfo.Status = model.TaskStatusFailure
|
||||
if taskResp.ErrCode != "" {
|
||||
taskInfo.Reason = taskResp.ErrCode
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown task state: %s", state)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
@@ -25,6 +25,11 @@ type Adaptor struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -48,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -106,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
var tencentResponse TencentChatResponse
|
||||
err := json.Unmarshal([]byte(data), &tencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -117,17 +117,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
common.SysLog(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
common.SysLog("error reading stream: " + err.Error())
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
}
|
||||
@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
|
||||
var tencentSb TencentChatResponseSB
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if tencentSb.Response.Error.Code != 0 {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
@@ -156,7 +156,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return &fullTextResponse.Usage, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
|
||||
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
|
||||
"claude-opus-4-20250514": "claude-opus-4@20250514",
|
||||
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
@@ -44,6 +45,11 @@ type Adaptor struct {
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
c.Set("request_model", v)
|
||||
@@ -60,17 +66,17 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
geminiAdaptor := gemini.Adaptor{}
|
||||
return geminiAdaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude") {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
|
||||
a.RequestMode = RequestModeGemini
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
} else {
|
||||
a.RequestMode = RequestModeGemini
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,6 +89,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
a.AccountCredentials = *adc
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -100,6 +107,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
@@ -169,8 +181,62 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
prompt := ""
|
||||
for _, m := range request.Messages {
|
||||
if m.Role == "user" {
|
||||
prompt = m.StringContent()
|
||||
if prompt != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if prompt == "" {
|
||||
if p, ok := request.Prompt.(string); ok {
|
||||
prompt = p
|
||||
}
|
||||
}
|
||||
if prompt == "" {
|
||||
return nil, errors.New("prompt is required for image generation")
|
||||
}
|
||||
|
||||
imgReq := dto.ImageRequest{
|
||||
Model: request.Model,
|
||||
Prompt: prompt,
|
||||
N: 1,
|
||||
Size: "1024x1024",
|
||||
}
|
||||
if request.N > 0 {
|
||||
imgReq.N = uint(request.N)
|
||||
}
|
||||
if request.Size != "" {
|
||||
imgReq.Size = request.Size
|
||||
}
|
||||
if len(request.ExtraBody) > 0 {
|
||||
var extra map[string]any
|
||||
if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
|
||||
if n, ok := extra["n"].(float64); ok && n > 0 {
|
||||
imgReq.N = uint(n)
|
||||
}
|
||||
if size, ok := extra["size"].(string); ok {
|
||||
imgReq.Size = size
|
||||
}
|
||||
// accept aspectRatio in extra body (top-level or under parameters)
|
||||
if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
|
||||
imgReq.Size = ar
|
||||
}
|
||||
if params, ok := extra["parameters"].(map[string]any); ok {
|
||||
if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
|
||||
imgReq.Size = ar
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
return a.ConvertImageRequest(c, info, imgReq)
|
||||
}
|
||||
if a.RequestMode == RequestModeClaude {
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -179,7 +245,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -213,28 +279,31 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
|
||||
return gemini.GeminiChatStreamHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
return openai.OaiStreamHandler(c, info, resp)
|
||||
}
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
return gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = gemini.GeminiChatHandler(c, info, resp)
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
return gemini.GeminiImageHandler(c, info, resp)
|
||||
}
|
||||
return gemini.GeminiChatHandler(c, info, resp)
|
||||
}
|
||||
case RequestModeLlama:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
return openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
||||
})
|
||||
|
||||
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
var cacheKey string
|
||||
if info.ChannelIsMultiKey {
|
||||
cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
|
||||
} else {
|
||||
cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
}
|
||||
val, err := Cache.Get(cacheKey)
|
||||
if err == nil {
|
||||
return val.(string), nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package volcengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -23,10 +24,14 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -184,13 +189,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
@@ -206,6 +215,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 适配 方舟deepseek混合模型 的 thinking 后缀
|
||||
if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
request.Model = info.UpstreamModelName
|
||||
request.THINKING = json.RawMessage(`{"type": "enabled"}`)
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -227,18 +242,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||
}
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
@@ -34,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
xaiRequest := ImageRequest{
|
||||
Model: request.Model,
|
||||
Prompt: request.Prompt,
|
||||
N: request.N,
|
||||
N: int(request.N),
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
}
|
||||
return xaiRequest, nil
|
||||
@@ -44,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -47,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &xAIResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
common.SysLog(err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -74,12 +74,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -101,7 +101,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||
service.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||
|
||||
return xaiResponse.Usage, nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ type Adaptor struct {
|
||||
request *dto.GeneralOpenAIRequest
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
|
||||
@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
|
||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
xunfeiRequest.Parameter.Chat.TopK = request.N
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
|
||||
xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
|
||||
xunfeiRequest.Payload.Message.Text = messages
|
||||
return &xunfeiRequest
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
@@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
if err != nil || resp.StatusCode != 101 {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
@@ -218,20 +223,19 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
common.SysLog("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
var response XunfeiChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
dataChan <- response
|
||||
if response.Payload.Choices.Status == 2 {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing websocket connection: " + err.Error())
|
||||
common.SysLog("error closing websocket connection: " + err.Error())
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
@@ -16,6 +16,11 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
@@ -40,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.IsStream {
|
||||
method = "sse-invoke"
|
||||
}
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -38,7 +39,7 @@ func getZhipuToken(apikey string) string {
|
||||
|
||||
split := strings.Split(apikey, ".")
|
||||
if len(split) != 2 {
|
||||
common.SysError("invalid zhipu key: " + apikey)
|
||||
common.SysLog("invalid zhipu key: " + apikey)
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -186,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
response := streamResponseZhipu2OpenAI(data)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
@@ -195,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
var zhipuResponse ZhipuStreamMetaResponse
|
||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
common.SysLog("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
usage = zhipuUsage
|
||||
@@ -212,7 +213,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
return false
|
||||
}
|
||||
})
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -220,12 +221,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
|
||||
var zhipuResponse ZhipuResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
@@ -18,10 +19,13 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
return nil, nil
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -38,19 +42,22 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/embeddings", baseUrl), nil
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/api/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/paas/v4/embeddings", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Set("Authorization", token)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -82,12 +89,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -1,69 +1,10 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||
// chatglm_std, chatglm_lite
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
func getZhipuToken(apikey string) string {
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(tokenData)
|
||||
if time.Now().Before(tokenData.ExpiryTime) {
|
||||
return tokenData.Token
|
||||
}
|
||||
}
|
||||
|
||||
split := strings.Split(apikey, ".")
|
||||
if len(split) != 2 {
|
||||
common.SysError("invalid zhipu key: " + apikey)
|
||||
return ""
|
||||
}
|
||||
|
||||
id := split[0]
|
||||
secret := split[1]
|
||||
|
||||
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
|
||||
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
|
||||
|
||||
timestamp := time.Now().UnixNano() / 1e6
|
||||
|
||||
payload := jwt.MapClaims{
|
||||
"api_key": id,
|
||||
"exp": expMillis,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
|
||||
token.Header["alg"] = "HS256"
|
||||
token.Header["sign_type"] = "SIGN"
|
||||
|
||||
tokenString, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
zhipuTokens.Store(apikey, tokenData{
|
||||
Token: tokenString,
|
||||
ExpiryTime: expiryTime,
|
||||
})
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
@@ -105,9 +46,10 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
THINKING: request.THINKING,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -18,119 +17,99 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
||||
textRequest = &dto.ClaudeRequest{}
|
||||
err = c.ShouldBindJSON(textRequest)
|
||||
func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
claudeReq, ok := info.Request.(*dto.ClaudeRequest)
|
||||
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(claudeReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("field model is required")
|
||||
}
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoClaude(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateClaudeRequest(c)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.Stream {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
var requestBody io.Reader
|
||||
adaptor.Init(info)
|
||||
|
||||
if textRequest.MaxTokens == 0 {
|
||||
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
if request.MaxTokens == 0 {
|
||||
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
if textRequest.Thinking == nil {
|
||||
strings.HasSuffix(request.Model, "-thinking") {
|
||||
if request.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
if textRequest.MaxTokens < 1280 {
|
||||
textRequest.MaxTokens = 1280
|
||||
if request.MaxTokens < 1280 {
|
||||
request.MaxTokens = 1280
|
||||
}
|
||||
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
request.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
textRequest.TopP = 0
|
||||
textRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
request.TopP = 0
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
}
|
||||
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
request.Model = strings.TrimSuffix(request.Model, "-thinking")
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
@@ -139,24 +118,14 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||||
//log.Printf("usage: %v", usage)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
|
||||
service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
var promptTokens int
|
||||
var err error
|
||||
switch info.RelayMode {
|
||||
default:
|
||||
promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
|
||||
}
|
||||
info.PromptTokens = promptTokens
|
||||
return promptTokens, err
|
||||
}
|
||||
|
||||
435
relay/common/override.go
Normal file
435
relay/common/override.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ConditionOperation struct {
|
||||
Path string `json:"path"` // JSON路径
|
||||
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||
Value interface{} `json:"value"` // 匹配的值
|
||||
Invert bool `json:"invert"` // 反选功能,true表示取反结果
|
||||
PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为
|
||||
}
|
||||
|
||||
type ParamOperation struct {
|
||||
Path string `json:"path"`
|
||||
Mode string `json:"mode"` // delete, set, move, prepend, append
|
||||
Value interface{} `json:"value"`
|
||||
KeepOrigin bool `json:"keep_origin"`
|
||||
From string `json:"from,omitempty"`
|
||||
To string `json:"to,omitempty"`
|
||||
Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表
|
||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||
}
|
||||
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
// 尝试断言为操作格式
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
// 使用新方法
|
||||
result, err := applyOperations(string(jsonData), operations)
|
||||
return []byte(result), err
|
||||
}
|
||||
|
||||
// 直接使用旧方法
|
||||
return applyOperationsLegacy(jsonData, paramOverride)
|
||||
}
|
||||
|
||||
func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
|
||||
// 检查是否包含 "operations" 字段
|
||||
if opsValue, exists := paramOverride["operations"]; exists {
|
||||
if opsSlice, ok := opsValue.([]interface{}); ok {
|
||||
var operations []ParamOperation
|
||||
for _, op := range opsSlice {
|
||||
if opMap, ok := op.(map[string]interface{}); ok {
|
||||
operation := ParamOperation{}
|
||||
|
||||
// 断言必要字段
|
||||
if path, ok := opMap["path"].(string); ok {
|
||||
operation.Path = path
|
||||
}
|
||||
if mode, ok := opMap["mode"].(string); ok {
|
||||
operation.Mode = mode
|
||||
} else {
|
||||
return nil, false // mode 是必需的
|
||||
}
|
||||
|
||||
// 可选字段
|
||||
if value, exists := opMap["value"]; exists {
|
||||
operation.Value = value
|
||||
}
|
||||
if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
|
||||
operation.KeepOrigin = keepOrigin
|
||||
}
|
||||
if from, ok := opMap["from"].(string); ok {
|
||||
operation.From = from
|
||||
}
|
||||
if to, ok := opMap["to"].(string); ok {
|
||||
operation.To = to
|
||||
}
|
||||
if logic, ok := opMap["logic"].(string); ok {
|
||||
operation.Logic = logic
|
||||
} else {
|
||||
operation.Logic = "OR" // 默认为OR
|
||||
}
|
||||
|
||||
// 解析条件
|
||||
if conditions, exists := opMap["conditions"]; exists {
|
||||
if condSlice, ok := conditions.([]interface{}); ok {
|
||||
for _, cond := range condSlice {
|
||||
if condMap, ok := cond.(map[string]interface{}); ok {
|
||||
condition := ConditionOperation{}
|
||||
if path, ok := condMap["path"].(string); ok {
|
||||
condition.Path = path
|
||||
}
|
||||
if mode, ok := condMap["mode"].(string); ok {
|
||||
condition.Mode = mode
|
||||
}
|
||||
if value, ok := condMap["value"]; ok {
|
||||
condition.Value = value
|
||||
}
|
||||
if invert, ok := condMap["invert"].(bool); ok {
|
||||
condition.Invert = invert
|
||||
}
|
||||
if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
|
||||
condition.PassMissingKey = passMissingKey
|
||||
}
|
||||
operation.Conditions = append(operation.Conditions, condition)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
operations = append(operations, operation)
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return operations, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
if len(conditions) == 0 {
|
||||
return true, nil // 没有条件,直接通过
|
||||
}
|
||||
results := make([]bool, len(conditions))
|
||||
for i, condition := range conditions {
|
||||
result, err := checkSingleCondition(jsonStr, condition)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
results[i] = result
|
||||
}
|
||||
|
||||
if strings.ToUpper(logic) == "AND" {
|
||||
for _, result := range results {
|
||||
if !result {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
} else {
|
||||
for _, result := range results {
|
||||
if result {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
if !value.Exists() {
|
||||
if condition.PassMissingKey {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 利用gjson的类型解析
|
||||
targetBytes, err := json.Marshal(condition.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||
}
|
||||
targetValue := gjson.ParseBytes(targetBytes)
|
||||
|
||||
result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
|
||||
}
|
||||
|
||||
if condition.Invert {
|
||||
result = !result
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
re := regexp.MustCompile(`\.(-\d+)`)
|
||||
matches := re.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return path
|
||||
}
|
||||
|
||||
result := path
|
||||
for _, match := range matches {
|
||||
negIndex := match[1]
|
||||
index, _ := strconv.Atoi(negIndex)
|
||||
|
||||
arrayPath := strings.Split(path, negIndex)[0]
|
||||
if strings.HasSuffix(arrayPath, ".") {
|
||||
arrayPath = arrayPath[:len(arrayPath)-1]
|
||||
}
|
||||
|
||||
array := gjson.Get(jsonStr, arrayPath)
|
||||
if array.IsArray() {
|
||||
length := len(array.Array())
|
||||
actualIndex := length + index
|
||||
if actualIndex >= 0 && actualIndex < length {
|
||||
result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
|
||||
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
|
||||
switch mode {
|
||||
case "full":
|
||||
return compareEqual(jsonValue, targetValue)
|
||||
case "prefix":
|
||||
return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
|
||||
case "suffix":
|
||||
return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
|
||||
case "contains":
|
||||
return strings.Contains(jsonValue.String(), targetValue.String()), nil
|
||||
case "gt":
|
||||
return compareNumeric(jsonValue, targetValue, "gt")
|
||||
case "gte":
|
||||
return compareNumeric(jsonValue, targetValue, "gte")
|
||||
case "lt":
|
||||
return compareNumeric(jsonValue, targetValue, "lt")
|
||||
case "lte":
|
||||
return compareNumeric(jsonValue, targetValue, "lte")
|
||||
default:
|
||||
return false, fmt.Errorf("unsupported comparison mode: %s", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
|
||||
// 对布尔值特殊处理
|
||||
if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
|
||||
(targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
|
||||
return jsonValue.Bool() == targetValue.Bool(), nil
|
||||
}
|
||||
|
||||
// 如果类型不同,报错
|
||||
if jsonValue.Type != targetValue.Type {
|
||||
return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
|
||||
}
|
||||
|
||||
switch jsonValue.Type {
|
||||
case gjson.True, gjson.False:
|
||||
return jsonValue.Bool() == targetValue.Bool(), nil
|
||||
case gjson.Number:
|
||||
return jsonValue.Num == targetValue.Num, nil
|
||||
case gjson.String:
|
||||
return jsonValue.String() == targetValue.String(), nil
|
||||
default:
|
||||
return jsonValue.String() == targetValue.String(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
|
||||
// 只有数字类型才支持数值比较
|
||||
if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
|
||||
return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
|
||||
}
|
||||
|
||||
jsonNum := jsonValue.Num
|
||||
targetNum := targetValue.Num
|
||||
|
||||
switch operator {
|
||||
case "gt":
|
||||
return jsonNum > targetNum, nil
|
||||
case "gte":
|
||||
return jsonNum >= targetNum, nil
|
||||
case "lt":
|
||||
return jsonNum < targetNum, nil
|
||||
case "lte":
|
||||
return jsonNum <= targetNum, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unsupported numeric operator: %s", operator)
|
||||
}
|
||||
}
|
||||
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key, value := range paramOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
||||
result := jsonStr
|
||||
for _, op := range operations {
|
||||
// 检查条件是否满足
|
||||
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
continue // 条件不满足,跳过当前操作
|
||||
}
|
||||
// 处理路径中的负数索引
|
||||
opPath := processNegativeIndex(result, op.Path)
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
|
||||
switch op.Mode {
|
||||
case "delete":
|
||||
result, err = sjson.Delete(result, opPath)
|
||||
case "set":
|
||||
if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
|
||||
continue
|
||||
}
|
||||
result, err = sjson.Set(result, opPath, op.Value)
|
||||
case "move":
|
||||
result, err = moveValue(result, opFrom, opTo)
|
||||
case "prepend":
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
|
||||
case "append":
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
}
|
||||
result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sjson.Delete(result, fromPath)
|
||||
}
|
||||
|
||||
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
switch {
|
||||
case current.IsArray():
|
||||
return modifyArray(jsonStr, path, value, isPrepend)
|
||||
case current.Type == gjson.String:
|
||||
return modifyString(jsonStr, path, value, isPrepend)
|
||||
case current.Type == gjson.JSON:
|
||||
return mergeObjects(jsonStr, path, value, keepOrigin)
|
||||
}
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
var newArray []interface{}
|
||||
// 添加新值
|
||||
addValue := func() {
|
||||
if arr, ok := value.([]interface{}); ok {
|
||||
newArray = append(newArray, arr...)
|
||||
} else {
|
||||
newArray = append(newArray, value)
|
||||
}
|
||||
}
|
||||
// 添加原值
|
||||
addOriginal := func() {
|
||||
current.ForEach(func(_, val gjson.Result) bool {
|
||||
newArray = append(newArray, val.Value())
|
||||
return true
|
||||
})
|
||||
}
|
||||
if isPrepend {
|
||||
addValue()
|
||||
addOriginal()
|
||||
} else {
|
||||
addOriginal()
|
||||
addValue()
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newArray)
|
||||
}
|
||||
|
||||
func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
var newStr string
|
||||
if isPrepend {
|
||||
newStr = valueStr + current.String()
|
||||
} else {
|
||||
newStr = current.String() + valueStr
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newStr)
|
||||
}
|
||||
|
||||
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
var currentMap, newMap map[string]interface{}
|
||||
|
||||
// 解析当前值
|
||||
if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 解析新值
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}:
|
||||
newMap = v
|
||||
default:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
// 合并
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range currentMap {
|
||||
result[k] = v
|
||||
}
|
||||
for k, v := range newMap {
|
||||
if !keepOrigin || result[k] == nil {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return sjson.Set(jsonStr, path, result)
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -33,17 +36,6 @@ type ClaudeConvertInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
const (
|
||||
RelayFormatOpenAI = "openai"
|
||||
RelayFormatClaude = "claude"
|
||||
RelayFormatGemini = "gemini"
|
||||
RelayFormatOpenAIResponses = "openai_responses"
|
||||
RelayFormatOpenAIAudio = "openai_audio"
|
||||
RelayFormatOpenAIImage = "openai_image"
|
||||
RelayFormatRerank = "rerank"
|
||||
RelayFormatEmbedding = "embedding"
|
||||
)
|
||||
|
||||
type RerankerInfo struct {
|
||||
Documents []any
|
||||
ReturnDocuments bool
|
||||
@@ -59,9 +51,27 @@ type ResponsesUsageInfo struct {
|
||||
BuiltInTools map[string]*BuildInToolInfo
|
||||
}
|
||||
|
||||
type ChannelMeta struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
ChannelIsMultiKey bool
|
||||
ChannelMultiKeyIndex int
|
||||
ChannelBaseUrl string
|
||||
ApiType int
|
||||
ApiVersion string
|
||||
ApiKey string
|
||||
Organization string
|
||||
ChannelCreateTime int64
|
||||
ParamOverride map[string]interface{}
|
||||
HeadersOverride map[string]interface{}
|
||||
ChannelSetting dto.ChannelSettings
|
||||
ChannelOtherSettings dto.ChannelOtherSettings
|
||||
UpstreamModelName string
|
||||
IsModelMapped bool
|
||||
SupportStreamOptions bool // 是否支持流式选项
|
||||
}
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
@@ -72,43 +82,169 @@ type RelayInfo struct {
|
||||
FirstResponseTime time.Time
|
||||
isFirstResponse bool
|
||||
//SendLastReasoningResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
//RecodeModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
ApiKey string
|
||||
Organization string
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
IsModelMapped bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
OutputAudioFormat string
|
||||
RealtimeTools []dto.RealTimeTool
|
||||
IsFirstRequest bool
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
ChannelSetting dto.ChannelSettings
|
||||
ParamOverride map[string]interface{}
|
||||
UserSetting dto.UserSetting
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat string
|
||||
SendResponseCount int
|
||||
ChannelCreateTime int64
|
||||
IsStream bool
|
||||
IsGeminiBatchEmbedding bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
OriginModelName string
|
||||
RequestURLPath string
|
||||
PromptTokens int
|
||||
ShouldIncludeUsage bool
|
||||
DisablePing bool // 是否禁止向下游发送自定义 Ping
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
OutputAudioFormat string
|
||||
RealtimeTools []dto.RealTimeTool
|
||||
IsFirstRequest bool
|
||||
AudioUsage bool
|
||||
ReasoningEffort string
|
||||
UserSetting dto.UserSetting
|
||||
UserEmail string
|
||||
UserQuota int
|
||||
RelayFormat types.RelayFormat
|
||||
SendResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
Request dto.Request
|
||||
|
||||
ThinkingContentInfo
|
||||
*ClaudeConvertInfo
|
||||
*RerankerInfo
|
||||
*ResponsesUsageInfo
|
||||
*ChannelMeta
|
||||
*TaskRelayInfo
|
||||
}
|
||||
|
||||
func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
headerOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelHeaderOverride)
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
channelMeta := &ChannelMeta{
|
||||
ChannelType: channelType,
|
||||
ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
|
||||
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
|
||||
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
|
||||
ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
ParamOverride: paramOverride,
|
||||
HeadersOverride: headerOverride,
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
IsModelMapped: false,
|
||||
SupportStreamOptions: false,
|
||||
}
|
||||
|
||||
if channelType == constant.ChannelTypeAzure {
|
||||
channelMeta.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if channelType == constant.ChannelTypeVertexAi {
|
||||
channelMeta.ApiVersion = c.GetString("region")
|
||||
}
|
||||
|
||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||
if ok {
|
||||
channelMeta.ChannelSetting = channelSetting
|
||||
}
|
||||
|
||||
channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
|
||||
if ok {
|
||||
channelMeta.ChannelOtherSettings = channelOtherSettings
|
||||
}
|
||||
|
||||
if streamSupportedChannels[channelMeta.ChannelType] {
|
||||
channelMeta.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
info.ChannelMeta = channelMeta
|
||||
|
||||
// reset some fields based on channel meta
|
||||
// 重置某些字段,例如模型名称等
|
||||
if info.Request != nil {
|
||||
info.Request.SetModelName(info.OriginModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) ToString() string {
|
||||
if info == nil {
|
||||
return "RelayInfo<nil>"
|
||||
}
|
||||
|
||||
// Basic info
|
||||
b := &strings.Builder{}
|
||||
fmt.Fprintf(b, "RelayInfo{ ")
|
||||
fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
|
||||
fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
|
||||
fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
|
||||
fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
|
||||
fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
|
||||
fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
|
||||
fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
|
||||
fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
|
||||
fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
|
||||
fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
|
||||
fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
|
||||
|
||||
// User & token info (mask secrets)
|
||||
fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
|
||||
info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
|
||||
fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
|
||||
|
||||
// Time info
|
||||
latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
|
||||
fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
|
||||
info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
|
||||
|
||||
// Audio / realtime
|
||||
if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
|
||||
fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
|
||||
info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
|
||||
}
|
||||
|
||||
// Reasoning
|
||||
if info.ReasoningEffort != "" {
|
||||
fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
|
||||
}
|
||||
|
||||
// Price data (non-sensitive)
|
||||
if info.PriceData.UsePrice {
|
||||
fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
|
||||
}
|
||||
|
||||
// Channel metadata (mask ApiKey)
|
||||
if info.ChannelMeta != nil {
|
||||
cm := info.ChannelMeta
|
||||
fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
|
||||
cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
|
||||
}
|
||||
|
||||
// Responses usage info (non-sensitive)
|
||||
if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
|
||||
fmt.Fprintf(b, "ResponsesTools{ ")
|
||||
first := true
|
||||
for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
|
||||
if !first {
|
||||
fmt.Fprintf(b, ", ")
|
||||
}
|
||||
first = false
|
||||
if tool != nil {
|
||||
fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
|
||||
} else {
|
||||
fmt.Fprintf(b, "%s: calls=0", name)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(b, " }, ")
|
||||
}
|
||||
|
||||
fmt.Fprintf(b, "}")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
@@ -127,7 +263,8 @@ var streamSupportedChannels = map[int]bool{
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info := genBaseRelayInfo(c, nil)
|
||||
info.RelayFormat = types.RelayFormatOpenAIRealtime
|
||||
info.ClientWs = ws
|
||||
info.InputAudioFormat = "pcm16"
|
||||
info.OutputAudioFormat = "pcm16"
|
||||
@@ -135,9 +272,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatClaude
|
||||
func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatClaude
|
||||
info.ShouldIncludeUsage = false
|
||||
info.ClaudeConvertInfo = &ClaudeConvertInfo{
|
||||
LastMessagesType: LastMessageTypeNone,
|
||||
@@ -145,41 +282,39 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
info.RelayFormat = RelayFormatRerank
|
||||
info.RelayFormat = types.RelayFormatRerank
|
||||
info.RerankerInfo = &RerankerInfo{
|
||||
Documents: req.Documents,
|
||||
ReturnDocuments: req.GetReturnDocuments(),
|
||||
Documents: request.Documents,
|
||||
ReturnDocuments: request.GetReturnDocuments(),
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIAudio
|
||||
func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatOpenAIAudio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatEmbedding
|
||||
func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatEmbedding
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayMode = relayconstant.RelayModeResponses
|
||||
info.RelayFormat = RelayFormatOpenAIResponses
|
||||
|
||||
info.SupportStreamOptions = false
|
||||
info.RelayFormat = types.RelayFormatOpenAIResponses
|
||||
|
||||
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
||||
BuiltInTools: make(map[string]*BuildInToolInfo),
|
||||
}
|
||||
if len(req.Tools) > 0 {
|
||||
for _, tool := range req.Tools {
|
||||
if len(request.Tools) > 0 {
|
||||
for _, tool := range request.GetToolsMap() {
|
||||
toolType := common.Interface2String(tool["type"])
|
||||
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
|
||||
ToolName: toolType,
|
||||
@@ -195,93 +330,87 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
|
||||
}
|
||||
}
|
||||
}
|
||||
info.IsStream = req.Stream
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatGemini
|
||||
func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatGemini
|
||||
info.ShouldIncludeUsage = false
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIImage
|
||||
func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatOpenAIImage
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
info.RelayFormat = types.RelayFormatOpenAI
|
||||
return info
|
||||
}
|
||||
|
||||
func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
|
||||
//channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
//channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
//paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||
|
||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
|
||||
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
|
||||
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now()
|
||||
}
|
||||
|
||||
isStream := false
|
||||
|
||||
if request != nil {
|
||||
isStream = request.IsStream(c)
|
||||
}
|
||||
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
Request: request,
|
||||
|
||||
UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
|
||||
|
||||
TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
|
||||
TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
|
||||
TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
|
||||
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
IsStream: isStream,
|
||||
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
//RecodeModelName: c.GetString("original_model"),
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
|
||||
ChannelCreateTime: c.GetInt64("channel_create_time"),
|
||||
ParamOverride: paramOverride,
|
||||
RelayFormat: RelayFormatOpenAI,
|
||||
ThinkingContentInfo: ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
},
|
||||
}
|
||||
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
info.RelayMode = c.GetInt("relay_mode")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
|
||||
info.RequestURLPath = "/v1" + info.RequestURLPath
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = constant.ChannelBaseURLs[channelType]
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeVertexAi {
|
||||
info.ApiVersion = c.GetString("region")
|
||||
}
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
|
||||
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
|
||||
if ok {
|
||||
info.ChannelSetting = channelSetting
|
||||
}
|
||||
userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
|
||||
if ok {
|
||||
info.UserSetting = userSetting
|
||||
@@ -290,12 +419,43 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
return GenRelayInfoOpenAI(c, request), nil
|
||||
case types.RelayFormatOpenAIAudio:
|
||||
return GenRelayInfoOpenAIAudio(c, request), nil
|
||||
case types.RelayFormatOpenAIImage:
|
||||
return GenRelayInfoImage(c, request), nil
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
return GenRelayInfoWs(c, ws), nil
|
||||
case types.RelayFormatClaude:
|
||||
return GenRelayInfoClaude(c, request), nil
|
||||
case types.RelayFormatRerank:
|
||||
if request, ok := request.(*dto.RerankRequest); ok {
|
||||
return GenRelayInfoRerank(c, request), nil
|
||||
}
|
||||
return nil, errors.New("request is not a RerankRequest")
|
||||
case types.RelayFormatGemini:
|
||||
return GenRelayInfoGemini(c, request), nil
|
||||
case types.RelayFormatEmbedding:
|
||||
return GenRelayInfoEmbedding(c, request), nil
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
return GenRelayInfoResponses(c, request), nil
|
||||
}
|
||||
return nil, errors.New("request is not a OpenAIResponsesRequest")
|
||||
case types.RelayFormatTask:
|
||||
return genBaseRelayInfo(c, nil), nil
|
||||
case types.RelayFormatMjProxy:
|
||||
return genBaseRelayInfo(c, nil), nil
|
||||
default:
|
||||
return nil, errors.New("invalid relay format")
|
||||
}
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetIsStream(isStream bool) {
|
||||
info.IsStream = isStream
|
||||
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
|
||||
info.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
func (info *RelayInfo) SetFirstResponseTime() {
|
||||
@@ -310,20 +470,12 @@ func (info *RelayInfo) HasSendResponse() bool {
|
||||
}
|
||||
|
||||
type TaskRelayInfo struct {
|
||||
*RelayInfo
|
||||
Action string
|
||||
OriginTaskID string
|
||||
|
||||
ConsumeQuota bool
|
||||
}
|
||||
|
||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
info := &TaskRelayInfo{
|
||||
RelayInfo: GenRelayInfo(c),
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
|
||||
@@ -2,12 +2,10 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -16,9 +17,9 @@ import (
|
||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println("reranker response body: ", string(responseBody))
|
||||
}
|
||||
@@ -27,7 +28,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.Unmarshal(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||
for i, result := range xinRerankResponse.Results {
|
||||
@@ -62,7 +63,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
} else {
|
||||
err = common.Unmarshal(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||
}
|
||||
|
||||
@@ -2,213 +2,152 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
|
||||
textRequest := &dto.GeneralOpenAIRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, textRequest)
|
||||
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(textReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
|
||||
textRequest.Model = c.Param("model")
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return nil, errors.New("max_tokens is invalid")
|
||||
if request.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
if textRequest.WebSearchOptions.SearchContextSize != "" {
|
||||
validSizes := map[string]bool{
|
||||
"high": true,
|
||||
"medium": true,
|
||||
"low": true,
|
||||
}
|
||||
if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
|
||||
return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
|
||||
}
|
||||
} else {
|
||||
textRequest.WebSearchOptions.SearchContextSize = "medium"
|
||||
}
|
||||
}
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeCompletions:
|
||||
if textRequest.Prompt == "" {
|
||||
return nil, errors.New("field prompt is required")
|
||||
}
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
if len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
case relayconstant.RelayModeModerations:
|
||||
if textRequest.Input == nil || textRequest.Input == "" {
|
||||
return nil, errors.New("field input is required")
|
||||
}
|
||||
case relayconstant.RelayModeEdits:
|
||||
if textRequest.Instruction == "" {
|
||||
return nil, errors.New("field instruction is required")
|
||||
}
|
||||
}
|
||||
relayInfo.IsStream = textRequest.Stream
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
var promptTokens int
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
promptTokens = value.(int)
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
} else {
|
||||
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
}
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newApiErr != nil {
|
||||
return newApiErr
|
||||
}
|
||||
defer func() {
|
||||
if newApiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
includeUsage := false
|
||||
includeUsage := true
|
||||
// 判断用户是否需要返回使用情况
|
||||
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
|
||||
includeUsage = true
|
||||
if request.StreamOptions != nil {
|
||||
includeUsage = request.StreamOptions.IncludeUsage
|
||||
}
|
||||
|
||||
// 如果不支持StreamOptions,将StreamOptions设置为nil
|
||||
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
|
||||
textRequest.StreamOptions = nil
|
||||
if !info.SupportStreamOptions || !request.Stream {
|
||||
request.StreamOptions = nil
|
||||
} else {
|
||||
// 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
|
||||
if constant.ForceStreamOption {
|
||||
textRequest.StreamOptions = &dto.StreamOptions{
|
||||
request.StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if includeUsage {
|
||||
relayInfo.ShouldIncludeUsage = true
|
||||
}
|
||||
info.ShouldIncludeUsage = includeUsage
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
var requestBody io.Reader
|
||||
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(body))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
// 如果有系统提示,则将其添加到请求中
|
||||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("requestBody: ", string(jsonData))
|
||||
}
|
||||
logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
|
||||
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -217,125 +156,31 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newApiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newApiErr := service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if newApiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
var promptTokens int
|
||||
var err error
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
default:
|
||||
err = errors.New("unknown relay mode")
|
||||
promptTokens = 0
|
||||
}
|
||||
info.PromptTokens = promptTokens
|
||||
return promptTokens, err
|
||||
}
|
||||
|
||||
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
|
||||
var err error
|
||||
var words []string
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
words, err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
words, err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||
case relayconstant.RelayModeModerations:
|
||||
words, err = service.CheckSensitiveInput(textRequest.Input)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
words, err = service.CheckSensitiveInput(textRequest.Input)
|
||||
}
|
||||
return words, err
|
||||
}
|
||||
|
||||
// 预扣费并返回用户剩余配额
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||
}
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
// 非无限令牌,判断令牌额度是否充足
|
||||
tokenQuota := c.GetInt("token_quota")
|
||||
if tokenQuota > 100*preConsumedQuota {
|
||||
// 令牌额度充足,信任令牌
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
|
||||
}
|
||||
} else {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
|
||||
}
|
||||
}
|
||||
|
||||
if preConsumedQuota > 0 {
|
||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
}
|
||||
|
||||
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.PromptTokens,
|
||||
@@ -353,12 +198,12 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := priceData.CompletionRatio
|
||||
cacheRatio := priceData.CacheRatio
|
||||
imageRatio := priceData.ImageRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
imageRatio := relayInfo.PriceData.ImageRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -431,7 +276,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
if !priceData.UsePrice {
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
@@ -469,21 +314,27 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
} else {
|
||||
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
|
||||
}
|
||||
var dGeminiImageOutputQuota decimal.Decimal
|
||||
var imageOutputPrice float64
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName)
|
||||
if imageOutputPrice > 0 {
|
||||
dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens")))
|
||||
dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
}
|
||||
}
|
||||
// 添加 responses tools call 调用的配额
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 Gemini image output 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota)
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
if !priceData.UsePrice {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
@@ -491,18 +342,38 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
if !ratio.IsZero() && quota == 0 {
|
||||
quota = 1
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
|
||||
|
||||
//logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta)))
|
||||
|
||||
if quotaDelta > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
} else if quotaDelta < 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(-quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
}
|
||||
|
||||
if quotaDelta != 0 {
|
||||
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
||||
err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -518,7 +389,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
@@ -553,6 +424,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dGeminiImageOutputQuota.IsZero() {
|
||||
other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens")
|
||||
other["image_output_price"] = imageOutputPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
@@ -562,7 +437,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
@@ -40,11 +40,8 @@ const (
|
||||
RelayModeSunoFetchByID
|
||||
RelayModeSunoSubmit
|
||||
|
||||
RelayModeKlingFetchByID
|
||||
RelayModeKlingSubmit
|
||||
|
||||
RelayModeJimengFetchByID
|
||||
RelayModeJimengSubmit
|
||||
RelayModeVideoFetchByID
|
||||
RelayModeVideoSubmit
|
||||
|
||||
RelayModeRerank
|
||||
|
||||
@@ -87,6 +84,8 @@ func Path2RelayMode(path string) int {
|
||||
relayMode = RelayModeRealtime
|
||||
} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
|
||||
relayMode = RelayModeGemini
|
||||
} else if strings.HasPrefix(path, "/mj") {
|
||||
relayMode = Path2RelayModeMidjourney(path)
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
@@ -145,23 +144,3 @@ func Path2RelaySuno(method, path string) int {
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayKling(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||
relayMode = RelayModeKlingSubmit
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||
relayMode = RelayModeKlingFetchByID
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayJimeng(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||
relayMode = RelayModeJimengSubmit
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||
relayMode = RelayModeJimengFetchByID
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
@@ -16,80 +15,41 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||
return token
|
||||
}
|
||||
func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
|
||||
if embeddingRequest.Input == nil {
|
||||
return fmt.Errorf("input is empty")
|
||||
embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = "omni-moderation-latest"
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = c.Param("model")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
||||
|
||||
var embeddingRequest *dto.EmbeddingRequest
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
request, err := common.DeepCopy(embeddingReq)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -105,12 +65,12 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,17 +2,16 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/gemini"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
@@ -20,67 +19,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||
request := &gemini.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(request.Contents) == 0 {
|
||||
return nil, errors.New("contents is required")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// 流模式
|
||||
// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
|
||||
func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
if c.Query("alt") == "sse" {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
|
||||
// relayInfo.IsStream = true
|
||||
// }
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||
var inputTexts []string
|
||||
for _, content := range textRequest.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(inputTexts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||
info.PromptTokens = inputTokens
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
|
||||
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
|
||||
if configBudget != nil && *configBudget == 0 {
|
||||
// 如果思考预算为 0,则认为是非思考请求
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -105,108 +50,99 @@ func trimModelThinking(modelName string) string {
|
||||
return modelName
|
||||
}
|
||||
|
||||
func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateGeminiRequest(c)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
|
||||
// 检查 Gemini 流式模式
|
||||
checkGeminiStreamMode(c, relayInfo)
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
}
|
||||
request, err := common.DeepCopy(geminiReq)
|
||||
if err != nil {
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
promptTokens := value.(int)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
} else {
|
||||
promptTokens := getGeminiInputTokens(req, relayInfo)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if isNoThinkingRequest(req) {
|
||||
if isNoThinkingRequest(request) {
|
||||
// check is thinking
|
||||
if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
|
||||
if !strings.Contains(info.OriginModelName, "-nothinking") {
|
||||
// try to get no thinking model price
|
||||
noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
|
||||
noThinkingModelName := info.OriginModelName + "-nothinking"
|
||||
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
|
||||
if containPrice {
|
||||
relayInfo.OriginModelName = noThinkingModelName
|
||||
relayInfo.UpstreamModelName = noThinkingModelName
|
||||
info.OriginModelName = noThinkingModelName
|
||||
info.UpstreamModelName = noThinkingModelName
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.GenerationConfig.ThinkingConfig == nil {
|
||||
gemini.ThinkingAdaptor(req, relayInfo)
|
||||
if request.GenerationConfig.ThinkingConfig == nil {
|
||||
gemini.ThinkingAdaptor(request, info)
|
||||
}
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre consume quota
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
|
||||
// Clean up empty system instruction
|
||||
if req.SystemInstructions != nil {
|
||||
if request.SystemInstructions != nil {
|
||||
hasContent := false
|
||||
for _, part := range req.SystemInstructions.Parts {
|
||||
for _, part := range request.SystemInstructions.Parts {
|
||||
if part.Text != "" {
|
||||
hasContent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasContent {
|
||||
req.SystemInstructions = nil
|
||||
request.SystemInstructions = nil
|
||||
}
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewReader(body)
|
||||
} else {
|
||||
// 使用 ConvertGeminiRequest 转换请求格式
|
||||
convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(c, "Gemini request body: "+string(jsonData))
|
||||
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(requestBody))
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
logger.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
@@ -214,7 +150,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
@@ -223,12 +159,108 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
|
||||
if openaiErr != nil {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
|
||||
info.IsGeminiBatchEmbedding = isBatch
|
||||
|
||||
var req dto.Request
|
||||
var err error
|
||||
var inputTexts []string
|
||||
|
||||
if isBatch {
|
||||
batchRequest := &dto.GeminiBatchEmbeddingRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, batchRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
req = batchRequest
|
||||
for _, r := range batchRequest.Requests {
|
||||
for _, part := range r.Content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
singleRequest := &dto.GeminiEmbeddingRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, singleRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
req = singleRequest
|
||||
for _, part := range singleRequest.Content.Parts {
|
||||
if part.Text != "" {
|
||||
inputTexts = append(inputTexts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
var requestBody io.Reader
|
||||
jsonData, err := common.Marshal(req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
_ = common.Unmarshal(jsonData, &reqMap)
|
||||
for key, value := range info.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = common.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
logger.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), info)
|
||||
if openaiErr != nil {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,85 +1,80 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func FlushWriter(c *gin.Context) error {
|
||||
if c.Writer == nil {
|
||||
return nil
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
// 检查是否已经设置过头部
|
||||
if _, exists := c.Get("event_stream_headers_set"); exists {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置标志,表示头部已经设置过
|
||||
c.Set("event_stream_headers_set", true)
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
// 设置标志,表示头部已经设置过
|
||||
c.Set("event_stream_headers_set", true)
|
||||
}
|
||||
|
||||
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||
jsonData, err := json.Marshal(resp)
|
||||
jsonData, err := common.Marshal(resp)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
} else {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
|
||||
}
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
}
|
||||
|
||||
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
|
||||
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
}
|
||||
|
||||
func StringData(c *gin.Context, str string) error {
|
||||
//str = strings.TrimPrefix(str, "data: ")
|
||||
//str = strings.TrimSuffix(str, "\r")
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + str})
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func PingData(c *gin.Context) error {
|
||||
c.Writer.Write([]byte(": PING\n\n"))
|
||||
if flusher, ok := c.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else {
|
||||
return errors.New("streaming error: flusher not found")
|
||||
}
|
||||
_ = FlushWriter(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -100,7 +95,7 @@ func Done(c *gin.Context) {
|
||||
|
||||
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
|
||||
if ws == nil {
|
||||
common.LogError(c, "websocket connection is nil")
|
||||
logger.LogError(c, "websocket connection is nil")
|
||||
return errors.New("websocket connection is nil")
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
|
||||
@@ -108,12 +103,12 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
|
||||
}
|
||||
|
||||
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||
jsonData, err := json.Marshal(object)
|
||||
jsonData, err := common.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
if ws == nil {
|
||||
common.LogError(c, "websocket connection is nil")
|
||||
logger.LogError(c, "websocket connection is nil")
|
||||
return errors.New("websocket connection is nil")
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
|
||||
@@ -121,6 +116,9 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||
}
|
||||
|
||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
|
||||
if ws == nil {
|
||||
return
|
||||
}
|
||||
errorObj := &dto.RealtimeEvent{
|
||||
Type: "error",
|
||||
EventId: GetLocalRealtimeID(c),
|
||||
@@ -139,6 +137,24 @@ func GetLocalRealtimeID(c *gin.Context) string {
|
||||
return fmt.Sprintf("evt_%s", logID)
|
||||
}
|
||||
|
||||
func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
SystemFingerprint: systemFingerprint,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
Content: common.GetPointer(""),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
|
||||
@@ -4,14 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
common2 "one-api/common"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
"one-api/relay/common"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -53,40 +51,7 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) erro
|
||||
}
|
||||
}
|
||||
if request != nil {
|
||||
switch info.RelayFormat {
|
||||
case common.RelayFormatGemini:
|
||||
// Gemini 模型映射
|
||||
case common.RelayFormatClaude:
|
||||
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||
claudeRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIResponses:
|
||||
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIAudio:
|
||||
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||
openAIAudioRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIImage:
|
||||
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||
imageRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatRerank:
|
||||
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||
rerankRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatEmbedding:
|
||||
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||
embeddingRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
default:
|
||||
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||
openAIRequest.Model = info.UpstreamModelName
|
||||
} else {
|
||||
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||
}
|
||||
}
|
||||
request.SetModelName(info.UpstreamModelName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,35 +5,14 @@ import (
|
||||
"one-api/common"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GroupRatioInfo struct {
|
||||
GroupRatio float64
|
||||
GroupSpecialRatio float64
|
||||
HasSpecialRatio bool
|
||||
}
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
CompletionRatio float64
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
|
||||
groupRatioInfo := GroupRatioInfo{
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo {
|
||||
groupRatioInfo := types.GroupRatioInfo{
|
||||
GroupRatio: 1.0, // default ratio
|
||||
GroupSpecialRatio: -1,
|
||||
}
|
||||
@@ -62,7 +41,7 @@ func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupR
|
||||
return groupRatioInfo
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
|
||||
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
@@ -74,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
var imageRatio float64
|
||||
var cacheCreationRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if maxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + maxTokens
|
||||
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
|
||||
if meta.MaxTokens != 0 {
|
||||
preConsumedTokens += meta.MaxTokens
|
||||
}
|
||||
var success bool
|
||||
var matchName string
|
||||
@@ -87,7 +66,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
acceptUnsetRatio = true
|
||||
}
|
||||
if !acceptUnsetRatio {
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
|
||||
return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
|
||||
}
|
||||
}
|
||||
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
|
||||
@@ -97,10 +76,13 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
if meta.ImagePriceRatio != 0 {
|
||||
modelPrice = modelPrice * meta.ImagePriceRatio
|
||||
}
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
}
|
||||
|
||||
priceData := PriceData{
|
||||
priceData := types.PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
@@ -115,18 +97,12 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
|
||||
}
|
||||
|
||||
info.PriceData = priceData
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
type PerCallPriceData struct {
|
||||
ModelPrice float64
|
||||
Quota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||
@@ -140,7 +116,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCal
|
||||
}
|
||||
}
|
||||
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
priceData := PerCallPriceData{
|
||||
priceData := types.PerCallPriceData{
|
||||
ModelPrice: modelPrice,
|
||||
Quota: quota,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"strings"
|
||||
@@ -39,10 +40,6 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
}()
|
||||
|
||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") {
|
||||
// twice timeout for thinking model
|
||||
streamingTimeout *= 2
|
||||
}
|
||||
|
||||
var (
|
||||
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
|
||||
@@ -54,7 +51,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
)
|
||||
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
pingEnabled := generalSettings.PingIntervalEnabled
|
||||
pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = DefaultPingInterval
|
||||
@@ -91,7 +88,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
common.LogError(c, "timeout waiting for goroutines to exit")
|
||||
logger.LogError(c, "timeout waiting for goroutines to exit")
|
||||
}
|
||||
|
||||
close(stopChan)
|
||||
@@ -113,7 +110,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
defer func() {
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
||||
logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
@@ -140,14 +137,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
common.LogError(c, "ping data error: "+err.Error())
|
||||
logger.LogError(c, "ping data error: "+err.Error())
|
||||
return
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("ping data sent")
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
common.LogError(c, "ping data send timeout")
|
||||
logger.LogError(c, "ping data send timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -162,7 +159,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
// 监听客户端断开连接
|
||||
return
|
||||
case <-pingTimeout.C:
|
||||
common.LogError(c, "ping goroutine max duration reached")
|
||||
logger.LogError(c, "ping goroutine max duration reached")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -175,7 +172,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
defer func() {
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
if common.DebugEnabled {
|
||||
@@ -227,19 +224,25 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
return
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
common.LogError(c, "data handler timeout")
|
||||
logger.LogError(c, "data handler timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
|
||||
if common.DebugEnabled {
|
||||
println("received [DONE], stopping scanner")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err != io.EOF {
|
||||
common.LogError(c, "scanner error: "+err.Error())
|
||||
logger.LogError(c, "scanner error: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -248,12 +251,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
common.LogError(c, "streaming timeout")
|
||||
logger.LogError(c, "streaming timeout")
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
common.LogInfo(c, "streaming finished")
|
||||
logger.LogInfo(c, "streaming finished")
|
||||
case <-c.Request.Context().Done():
|
||||
// 客户端断开连接
|
||||
common.LogInfo(c, "client disconnected")
|
||||
logger.LogInfo(c, "client disconnected")
|
||||
}
|
||||
}
|
||||
|
||||
306
relay/helper/valid_request.go
Normal file
306
relay/helper/valid_request.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
|
||||
switch format {
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
case types.RelayFormatClaude:
|
||||
request, err = GetAndValidateClaudeRequest(c)
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
request, err = GetAndValidateResponsesRequest(c)
|
||||
|
||||
case types.RelayFormatOpenAIImage:
|
||||
request, err = GetAndValidOpenAIImageRequest(c, relayMode)
|
||||
case types.RelayFormatEmbedding:
|
||||
request, err = GetAndValidateEmbeddingRequest(c, relayMode)
|
||||
case types.RelayFormatRerank:
|
||||
request, err = GetAndValidateRerankRequest(c)
|
||||
case types.RelayFormatOpenAIAudio:
|
||||
request, err = GetAndValidAudioRequest(c, relayMode)
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
request = &dto.BaseRequest{}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported relay format: %s", format)
|
||||
}
|
||||
return request, err
|
||||
}
|
||||
|
||||
func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) {
|
||||
audioRequest := &dto.AudioRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, audioRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
default:
|
||||
err = c.Request.ParseForm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
if audioRequest.Model == "" {
|
||||
audioRequest.Model = formData.Get("model")
|
||||
}
|
||||
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
audioRequest.ResponseFormat = formData.Get("response_format")
|
||||
if audioRequest.ResponseFormat == "" {
|
||||
audioRequest.ResponseFormat = "json"
|
||||
}
|
||||
}
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) {
|
||||
var rerankRequest *dto.RerankRequest
|
||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||
if err != nil {
|
||||
logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if rerankRequest.Query == "" {
|
||||
return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
return rerankRequest, nil
|
||||
}
|
||||
|
||||
func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) {
|
||||
var embeddingRequest *dto.EmbeddingRequest
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
if err != nil {
|
||||
logger.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if embeddingRequest.Input == nil {
|
||||
return nil, fmt.Errorf("input is empty")
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = "omni-moderation-latest"
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
|
||||
embeddingRequest.Model = c.Param("model")
|
||||
}
|
||||
return embeddingRequest, nil
|
||||
}
|
||||
|
||||
func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
|
||||
request := &dto.OpenAIResponsesRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if request.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if request.Input == nil {
|
||||
return nil, errors.New("input is required")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) {
|
||||
imageRequest := &dto.ImageRequest{}
|
||||
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
_, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
imageRequest.Prompt = formData.Get("prompt")
|
||||
imageRequest.Model = formData.Get("model")
|
||||
imageRequest.N = uint(common.String2Int(formData.Get("n")))
|
||||
imageRequest.Quality = formData.Get("quality")
|
||||
imageRequest.Size = formData.Get("size")
|
||||
|
||||
if imageRequest.Model == "gpt-image-1" {
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
watermark := formData.Has("watermark")
|
||||
if watermark {
|
||||
imageRequest.Watermark = &watermark
|
||||
}
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
//imageRequest.Model = "dall-e-3"
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
|
||||
if strings.Contains(imageRequest.Size, "×") {
|
||||
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
|
||||
}
|
||||
|
||||
// Not "256x256", "512x512", or "1024x1024"
|
||||
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
} else if imageRequest.Model == "dall-e-3" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
||||
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
|
||||
}
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
} else if imageRequest.Model == "gpt-image-1" {
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "auto"
|
||||
}
|
||||
}
|
||||
|
||||
//if imageRequest.Prompt == "" {
|
||||
// return nil, errors.New("prompt is required")
|
||||
//}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
}
|
||||
|
||||
return imageRequest, nil
|
||||
}
|
||||
|
||||
func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
||||
textRequest = &dto.ClaudeRequest{}
|
||||
err = c.ShouldBindJSON(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("field model is required")
|
||||
}
|
||||
|
||||
//if textRequest.Stream {
|
||||
// relayInfo.IsStream = true
|
||||
//}
|
||||
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
|
||||
textRequest := &dto.GeneralOpenAIRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
|
||||
textRequest.Model = c.Param("model")
|
||||
}
|
||||
|
||||
if textRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return nil, errors.New("max_tokens is invalid")
|
||||
}
|
||||
if textRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
if textRequest.WebSearchOptions.SearchContextSize != "" {
|
||||
validSizes := map[string]bool{
|
||||
"high": true,
|
||||
"medium": true,
|
||||
"low": true,
|
||||
}
|
||||
if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
|
||||
return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
|
||||
}
|
||||
} else {
|
||||
textRequest.WebSearchOptions.SearchContextSize = "medium"
|
||||
}
|
||||
}
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeCompletions:
|
||||
if textRequest.Prompt == "" {
|
||||
return nil, errors.New("field prompt is required")
|
||||
}
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
if len(textRequest.Messages) == 0 {
|
||||
return nil, errors.New("field messages is required")
|
||||
}
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
case relayconstant.RelayModeModerations:
|
||||
if textRequest.Input == nil || textRequest.Input == "" {
|
||||
return nil, errors.New("field input is required")
|
||||
}
|
||||
case relayconstant.RelayModeEdits:
|
||||
if textRequest.Instruction == "" {
|
||||
return nil, errors.New("field instruction is required")
|
||||
}
|
||||
}
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(request.Contents) == 0 {
|
||||
return nil, errors.New("contents is required")
|
||||
}
|
||||
|
||||
//if c.Query("alt") == "sse" {
|
||||
// relayInfo.IsStream = true
|
||||
//}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
@@ -2,219 +2,94 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
|
||||
imageRequest := &dto.ImageRequest{}
|
||||
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
_, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
formData := c.Request.PostForm
|
||||
imageRequest.Prompt = formData.Get("prompt")
|
||||
imageRequest.Model = formData.Get("model")
|
||||
imageRequest.N = common.String2Int(formData.Get("n"))
|
||||
imageRequest.Quality = formData.Get("quality")
|
||||
imageRequest.Size = formData.Get("size")
|
||||
|
||||
if imageRequest.Model == "gpt-image-1" {
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
}
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
if info.ApiType == constant.APITypeVolcEngine {
|
||||
watermark := formData.Has("watermark")
|
||||
imageRequest.Watermark = &watermark
|
||||
}
|
||||
default:
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-3"
|
||||
}
|
||||
|
||||
if strings.Contains(imageRequest.Size, "×") {
|
||||
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
|
||||
}
|
||||
|
||||
// Not "256x256", "512x512", or "1024x1024"
|
||||
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
} else if imageRequest.Model == "dall-e-3" {
|
||||
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
|
||||
return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
|
||||
}
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
} else if imageRequest.Model == "gpt-image-1" {
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "auto"
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Prompt == "" {
|
||||
return nil, errors.New("prompt is required")
|
||||
}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
imageReq, ok := info.Request.(*dto.ImageRequest)
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return imageRequest, nil
|
||||
}
|
||||
|
||||
func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoImage(c)
|
||||
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
request, err := common.DeepCopy(imageReq)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
var preConsumedQuota int
|
||||
var quota int
|
||||
var userQuota int
|
||||
if !priceData.UsePrice {
|
||||
// modelRatio 16 = modelPrice $0.04
|
||||
// per 1 modelRatio = $0.04 / 16
|
||||
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
||||
preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
} else {
|
||||
sizeRatio := 1.0
|
||||
qualityRatio := 1.0
|
||||
|
||||
if strings.HasPrefix(imageRequest.Model, "dall-e") {
|
||||
// Size
|
||||
if imageRequest.Size == "256x256" {
|
||||
sizeRatio = 0.4
|
||||
} else if imageRequest.Size == "512x512" {
|
||||
sizeRatio = 0.45
|
||||
} else if imageRequest.Size == "1024x1024" {
|
||||
sizeRatio = 1
|
||||
} else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
||||
sizeRatio = 2
|
||||
}
|
||||
|
||||
if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
|
||||
qualityRatio = 2.0
|
||||
if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
|
||||
qualityRatio = 1.5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reset model price
|
||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
}
|
||||
if userQuota-quota < 0 {
|
||||
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
|
||||
}
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
|
||||
var requestBody io.Reader
|
||||
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("image request body: %s", requestBody))
|
||||
switch convertedRequest.(type) {
|
||||
case *bytes.Buffer:
|
||||
requestBody = convertedRequest.(io.Reader)
|
||||
default:
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
logger.LogDebug(c, fmt.Sprintf("image request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
@@ -223,7 +98,7 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
@@ -231,17 +106,23 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||
usage.(*dto.Usage).TotalTokens = imageRequest.N
|
||||
usage.(*dto.Usage).TotalTokens = int(request.N)
|
||||
}
|
||||
if usage.(*dto.Usage).PromptTokens == 0 {
|
||||
usage.(*dto.Usage).PromptTokens = imageRequest.N
|
||||
usage.(*dto.Usage).PromptTokens = int(request.N)
|
||||
}
|
||||
|
||||
quality := "standard"
|
||||
if imageRequest.Quality == "hd" {
|
||||
if request.Quality == "hd" {
|
||||
quality = "hd"
|
||||
}
|
||||
|
||||
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
|
||||
var logContent string
|
||||
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -170,26 +170,23 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
return
|
||||
}
|
||||
|
||||
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
startTime := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
//group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
||||
var swapFaceRequest dto.SwapFaceRequest
|
||||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
priceData := helper.ModelPriceHelperPerCall(c, info)
|
||||
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
userQuota, err := model.GetUserQuota(info.UserId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -212,32 +209,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
defer func() {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||
err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: channelId,
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: priceData.Quota,
|
||||
Content: logContent,
|
||||
TokenId: tokenId,
|
||||
UserQuota: userQuota,
|
||||
Group: relayInfo.UsingGroup,
|
||||
TokenId: info.TokenId,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
|
||||
}
|
||||
}()
|
||||
midjResponse := &mjResp.Response
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
UserId: info.UserId,
|
||||
Code: midjResponse.Code,
|
||||
Action: constant.MjActionSwapFace,
|
||||
MjId: midjResponse.Result,
|
||||
@@ -245,7 +241,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: startTime,
|
||||
SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
|
||||
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
@@ -300,7 +296,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, nil, respBody)
|
||||
service.IOCopyBytesGracefully(c, nil, respBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -369,14 +365,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
|
||||
return nil
|
||||
}
|
||||
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||||
|
||||
//tokenId := c.GetInt("token_id")
|
||||
//channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
||||
consumeQuota := true
|
||||
var midjRequest dto.MidjourneyRequest
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
@@ -384,35 +373,37 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||
relayInfo.InitChannelMeta(c)
|
||||
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||||
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||||
if mjErr != nil {
|
||||
return mjErr
|
||||
}
|
||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
||||
relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||
midjRequest.Action = constant.MjActionVideo
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
||||
}
|
||||
midjRequest.Action = constant.MjActionImagine
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
midjRequest.Action = constant.MjActionDescribe
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
||||
midjRequest.Action = constant.MjActionEdits
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||
midjRequest.Action = constant.MjActionShorten
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = constant.MjActionBlend
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = constant.MjActionUpload
|
||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||
mjId := ""
|
||||
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||
} else if midjRequest.Action == "" {
|
||||
@@ -422,7 +413,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
//action = midjRequest.Action
|
||||
mjId = midjRequest.TaskId
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||||
if midjRequest.Content == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||||
}
|
||||
@@ -432,13 +423,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
mjId = params.TaskId
|
||||
midjRequest.Action = params.Action
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
|
||||
//if midjRequest.MaskBase64 == "" {
|
||||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
||||
//}
|
||||
mjId = midjRequest.TaskId
|
||||
midjRequest.Action = constant.MjActionModal
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||
midjRequest.Action = constant.MjActionVideo
|
||||
if midjRequest.TaskId == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||
@@ -448,12 +439,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
mjId = midjRequest.TaskId
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
originTask := model.GetByMJId(relayInfo.UserId, mjId)
|
||||
if originTask == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
if setting.MjActionCheckSuccessEnabled {
|
||||
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
}
|
||||
}
|
||||
@@ -496,7 +487,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -521,24 +512,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: channelId,
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: priceData.Quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
Group: group,
|
||||
Group: relayInfo.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -550,7 +540,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||||
// other: 提交错误,description为错误描述
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
UserId: relayInfo.UserId,
|
||||
Code: midjResponse.Code,
|
||||
Action: midjRequest.Action,
|
||||
MjId: midjResponse.Result,
|
||||
@@ -572,7 +562,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
//无实例账号自动禁用渠道(No available account instance)
|
||||
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
|
||||
if err != nil {
|
||||
common.SysError("get_channel_null: " + err.Error())
|
||||
common.SysLog("get_channel_null: " + err.Error())
|
||||
}
|
||||
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||||
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|
||||
@@ -1,8 +1,8 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/constant"
|
||||
commonconstant "one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
"one-api/relay/channel/aws"
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/mistral"
|
||||
"one-api/relay/channel/mokaai"
|
||||
"one-api/relay/channel/moonshot"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
"one-api/relay/channel/volcengine"
|
||||
@@ -34,7 +36,8 @@ import (
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
"one-api/relay/channel/submodel"
|
||||
"strconv"
|
||||
"one-api/relay/channel/submodel"
|
||||
)
|
||||
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
@@ -97,22 +100,38 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &coze.Adaptor{}
|
||||
case constant.APITypeJimeng:
|
||||
return &jimeng.Adaptor{}
|
||||
case constant.APITypeMoonshot:
|
||||
return &moonshot.Adaptor{} // Moonshot uses Claude API
|
||||
case constant.APITypeSubmodel:
|
||||
return &submodel.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
||||
func GetTaskPlatform(c *gin.Context) constant.TaskPlatform {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType > 0 {
|
||||
return constant.TaskPlatform(strconv.Itoa(channelType))
|
||||
}
|
||||
return constant.TaskPlatform(c.GetString("platform"))
|
||||
}
|
||||
|
||||
func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
switch platform {
|
||||
//case constant.APITypeAIProxyLibrary:
|
||||
// return &aiproxy.Adaptor{}
|
||||
case commonconstant.TaskPlatformSuno:
|
||||
case constant.TaskPlatformSuno:
|
||||
return &suno.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformKling:
|
||||
return &kling.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
}
|
||||
if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil {
|
||||
switch channelType {
|
||||
case constant.ChannelTypeKling:
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
case constant.ChannelTypeVidu:
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,24 +22,31 @@ import (
|
||||
/*
|
||||
Task 任务通过平台、Action 区分任务
|
||||
*/
|
||||
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
info.InitChannelMeta(c)
|
||||
// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
relayInfo := relaycommon.GenTaskRelayInfo(c)
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
// get & validate taskRequest 获取并验证文本请求
|
||||
taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
|
||||
taskErr = adaptor.ValidateRequestAndSetAction(c, info)
|
||||
if taskErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||
modelName = service.CoverTaskActionToModelName(platform, info.Action)
|
||||
}
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
@@ -52,15 +59,15 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
|
||||
var ratio float64
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
|
||||
if hasUserGroupRatio {
|
||||
ratio = modelPrice * userGroupRatio
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
userQuota, err := model.GetUserQuota(info.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -71,8 +78,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
return
|
||||
}
|
||||
|
||||
if relayInfo.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -81,7 +88,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if originTask.ChannelId != relayInfo.ChannelId {
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
@@ -94,19 +101,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
c.Set("channel_id", originTask.ChannelId)
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
|
||||
relayInfo.BaseUrl = channel.GetBaseURL()
|
||||
relayInfo.ChannelId = originTask.ChannelId
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// do request
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -120,11 +127,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
|
||||
defer func() {
|
||||
// release quota
|
||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||
if info.ConsumeQuota && taskErr == nil {
|
||||
|
||||
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
||||
err := service.PostConsumeQuota(info, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
@@ -132,41 +139,40 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
if hasUserGroupRatio {
|
||||
gRatio = userGroupRatio
|
||||
}
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
Group: relayInfo.UsingGroup,
|
||||
TokenId: info.TokenId,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(info.ChannelId, quota)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
||||
if taskErr != nil {
|
||||
return
|
||||
}
|
||||
relayInfo.ConsumeQuota = true
|
||||
info.ConsumeQuota = true
|
||||
// insert task
|
||||
task := model.InitTask(platform, relayInfo)
|
||||
task := model.InitTask(platform, info)
|
||||
task.TaskID = taskID
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
task.Action = relayInfo.Action
|
||||
task.Action = info.Action
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
@@ -178,7 +184,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
|
||||
}
|
||||
|
||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
@@ -255,6 +261,9 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
|
||||
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
taskId := c.Param("task_id")
|
||||
if taskId == "" {
|
||||
taskId = c.GetString("task_id")
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
|
||||
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||
|
||||
@@ -3,86 +3,75 @@ package relay
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
for _, document := range rerankRequest.Documents {
|
||||
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||
token += tkm
|
||||
func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
rerankReq, ok := info.Request.(*dto.RerankRequest)
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
|
||||
|
||||
var rerankRequest *dto.RerankRequest
|
||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||
request, err := common.DeepCopy(rerankReq)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
|
||||
|
||||
if rerankRequest.Query == "" {
|
||||
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
if len(rerankRequest.Documents) == 0 {
|
||||
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
|
||||
}
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -99,12 +88,12 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -12,7 +10,6 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
@@ -20,111 +17,50 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
|
||||
request := &dto.OpenAIResponsesRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
responsesReq, ok := info.Request.(*dto.OpenAIResponsesRequest)
|
||||
if !ok {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
request, err := common.DeepCopy(responsesReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if request.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if len(request.Input) == 0 {
|
||||
return nil, errors.New("input is required")
|
||||
}
|
||||
return request, nil
|
||||
|
||||
}
|
||||
|
||||
func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
|
||||
sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
|
||||
inputTokens := service.CountTokenInput(req.Input, req.Model)
|
||||
info.PromptTokens = inputTokens
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
req, err := getAndValidateResponsesRequest(c)
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
sensitiveWords, err := checkInputSensitive(req, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
promptTokens := value.(int)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
} else {
|
||||
promptTokens := getInputTokens(req, relayInfo)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
// pre consume quota
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
requestBody = bytes.NewBuffer(body)
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
|
||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
// apply param override
|
||||
if len(relayInfo.ParamOverride) > 0 {
|
||||
reqMap := make(map[string]interface{})
|
||||
err = json.Unmarshal(jsonData, &reqMap)
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||
}
|
||||
for key, value := range relayInfo.ParamOverride {
|
||||
reqMap[key] = value
|
||||
}
|
||||
jsonData, err = json.Marshal(reqMap)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +71,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
@@ -153,17 +89,17 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
}
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
@@ -12,65 +11,35 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
||||
func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
|
||||
//if err != nil {
|
||||
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
|
||||
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
//}
|
||||
|
||||
err := helper.ModelMappedHelper(c, relayInfo, nil)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
adaptor.Init(info)
|
||||
//var requestBody io.Reader
|
||||
//firstWssRequest, _ := c.Get("first_wss_request")
|
||||
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
||||
resp, err := adaptor.DoRequest(c, info, nil)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
relayInfo.TargetWs = resp.(*websocket.Conn)
|
||||
defer relayInfo.TargetWs.Close()
|
||||
info.TargetWs = resp.(*websocket.Conn)
|
||||
defer info.TargetWs.Close()
|
||||
}
|
||||
|
||||
usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo)
|
||||
usage, newAPIError := adaptor.DoResponse(c, nil, info)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||
userQuota, priceData, "")
|
||||
service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "")
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user