mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-06-07 22:09:57 +00:00
Merge branch 'QuantumNous:main' into main
This commit is contained in:
@@ -265,6 +265,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(c, "do request failed: "+err.Error())
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
|
||||
@@ -21,6 +21,10 @@ var awsModelIDMap = map[string]string{
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
"nova-canvas-v1:0": "amazon.nova-canvas-v1:0",
|
||||
"nova-reel-v1:0": "amazon.nova-reel-v1:0",
|
||||
"nova-reel-v1:1": "amazon.nova-reel-v1:1",
|
||||
"nova-sonic-v1:0": "amazon.nova-sonic-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -82,10 +86,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"amazon.nova-canvas-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
}}
|
||||
},
|
||||
"amazon.nova-reel-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-reel-v1:1": {
|
||||
"us": true,
|
||||
},
|
||||
"amazon.nova-sonic-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
|
||||
@@ -3,17 +3,17 @@ package deepseek
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
@@ -44,14 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fimBaseUrl := info.ChannelBaseUrl
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,12 +92,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
|
||||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
|
||||
if strings.Contains(info.RequestURLPath, ":embedContent") ||
|
||||
strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
|
||||
return NativeGeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
if info.IsStream {
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
|
||||
var geminiSupportedMimeTypes = map[string]bool{
|
||||
"application/pdf": true,
|
||||
"audio/mpeg": true,
|
||||
@@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/webp": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
@@ -243,6 +245,7 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
|
||||
googleSearch := false
|
||||
codeExecution := false
|
||||
urlContext := false
|
||||
for _, tool := range textRequest.Tools {
|
||||
if tool.Function.Name == "googleSearch" {
|
||||
googleSearch = true
|
||||
@@ -252,6 +255,10 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
codeExecution = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Name == "urlContext" {
|
||||
urlContext = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
@@ -279,6 +286,11 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if urlContext {
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
URLContext: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
|
||||
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := openai.Adaptor{}
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -17,10 +18,7 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
@@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
|
||||
return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
return info.ChannelBaseUrl + "/v1/chat/completions", nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.ChannelBaseUrl + "/api/embed", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
|
||||
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
|
||||
return info.ChannelBaseUrl + "/api/chat", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
if request == nil { return nil, errors.New("request is nil") }
|
||||
// decide generate or chat
|
||||
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
|
||||
return openAIToGenerate(c, request)
|
||||
}
|
||||
return requestOpenAI2Ollama(c, request)
|
||||
return openAIChatToOllamaChat(c, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return requestOpenAI2Embeddings(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
@@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
return ollamaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
return ollamaStreamHandler(c, info, resp)
|
||||
}
|
||||
return ollamaChatHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -2,48 +2,69 @@ package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
type OllamaChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
type OllamaToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaTool struct {
|
||||
Type string `json:"type"`
|
||||
Function OllamaToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
type OllamaToolCall struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type OllamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaChatMessage `json:"messages"`
|
||||
Tools interface{} `json:"tools,omitempty"`
|
||||
Format interface{} `json:"format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
KeepAlive interface{} `json:"keep_alive,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaGenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Format interface{} `json:"format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
KeepAlive interface{} `json:"keep_alive,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Input []string `json:"input"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input interface{} `json:"input"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embedding [][]float64 `json:"embeddings,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float64 `json:"embeddings"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -14,121 +15,176 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if not base64
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
|
||||
chatReq := &OllamaChatRequest{
|
||||
Model: r.Model,
|
||||
Stream: r.Stream,
|
||||
Options: map[string]any{},
|
||||
Think: r.Think,
|
||||
}
|
||||
if r.ResponseFormat != nil {
|
||||
if r.ResponseFormat.Type == "json" {
|
||||
chatReq.Format = "json"
|
||||
} else if r.ResponseFormat.Type == "json_schema" {
|
||||
if len(r.ResponseFormat.JsonSchema) > 0 {
|
||||
var schema any
|
||||
_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
|
||||
chatReq.Format = schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// options mapping
|
||||
if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
|
||||
if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
|
||||
if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
|
||||
if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
|
||||
|
||||
if r.Stop != nil {
|
||||
switch v := r.Stop.(type) {
|
||||
case string:
|
||||
chatReq.Options["stop"] = []string{v}
|
||||
case []string:
|
||||
chatReq.Options["stop"] = v
|
||||
case []any:
|
||||
arr := make([]string,0,len(v))
|
||||
for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
|
||||
if len(arr)>0 { chatReq.Options["stop"] = arr }
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.Tools) > 0 {
|
||||
tools := make([]OllamaTool,0,len(r.Tools))
|
||||
for _, t := range r.Tools {
|
||||
tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
|
||||
}
|
||||
chatReq.Tools = tools
|
||||
}
|
||||
|
||||
chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
|
||||
for _, m := range r.Messages {
|
||||
var textBuilder strings.Builder
|
||||
var images []string
|
||||
if m.IsStringContent() {
|
||||
textBuilder.WriteString(m.StringContent())
|
||||
} else {
|
||||
parts := m.ParseContent()
|
||||
for _, part := range parts {
|
||||
if part.Type == dto.ContentTypeImageURL {
|
||||
img := part.GetImageMedia()
|
||||
if img != nil && img.Url != "" {
|
||||
var base64Data string
|
||||
if strings.HasPrefix(img.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
|
||||
if err != nil { return nil, err }
|
||||
base64Data = fileData.Base64Data
|
||||
} else if strings.HasPrefix(img.Url, "data:") {
|
||||
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
|
||||
} else {
|
||||
base64Data = img.Url
|
||||
}
|
||||
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
|
||||
if base64Data != "" { images = append(images, base64Data) }
|
||||
}
|
||||
mediaMessage.ImageUrl = imageUrl
|
||||
mediaMessages[j] = mediaMessage
|
||||
} else if part.Type == dto.ContentTypeText {
|
||||
textBuilder.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
message.SetMediaContent(mediaMessages)
|
||||
}
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
})
|
||||
cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
|
||||
if len(images)>0 { cm.Images = images }
|
||||
if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
|
||||
if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
|
||||
parsed := m.ParseToolCalls()
|
||||
if len(parsed) > 0 {
|
||||
calls := make([]OllamaToolCall,0,len(parsed))
|
||||
for _, tc := range parsed {
|
||||
var args interface{}
|
||||
if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
|
||||
if args==nil { args = map[string]any{} }
|
||||
oc := OllamaToolCall{}
|
||||
oc.Function.Name = tc.Function.Name
|
||||
oc.Function.Arguments = args
|
||||
calls = append(calls, oc)
|
||||
}
|
||||
cm.ToolCalls = calls
|
||||
}
|
||||
}
|
||||
chatReq.Messages = append(chatReq.Messages, cm)
|
||||
}
|
||||
str, ok := request.Stop.(string)
|
||||
var Stop []string
|
||||
if ok {
|
||||
Stop = []string{str}
|
||||
} else {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
ollamaRequest := &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
Prompt: request.Prompt,
|
||||
StreamOptions: request.StreamOptions,
|
||||
Suffix: request.Suffix,
|
||||
}
|
||||
ollamaRequest.Think = request.Think
|
||||
return ollamaRequest, nil
|
||||
return chatReq, nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Input: request.ParseInput(),
|
||||
Options: &Options{
|
||||
Seed: int(request.Seed),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
},
|
||||
// openAIToGenerate converts OpenAI completions request to Ollama generate
|
||||
func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
|
||||
gen := &OllamaGenerateRequest{
|
||||
Model: r.Model,
|
||||
Stream: r.Stream,
|
||||
Options: map[string]any{},
|
||||
Think: r.Think,
|
||||
}
|
||||
// Prompt may be in r.Prompt (string or []any)
|
||||
if r.Prompt != nil {
|
||||
switch v := r.Prompt.(type) {
|
||||
case string:
|
||||
gen.Prompt = v
|
||||
case []any:
|
||||
var sb strings.Builder
|
||||
for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
|
||||
gen.Prompt = sb.String()
|
||||
default:
|
||||
gen.Prompt = fmt.Sprintf("%v", r.Prompt)
|
||||
}
|
||||
}
|
||||
if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
|
||||
if r.ResponseFormat != nil {
|
||||
if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
|
||||
}
|
||||
if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
|
||||
if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
|
||||
if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
|
||||
if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
|
||||
if r.Stop != nil {
|
||||
switch v := r.Stop.(type) {
|
||||
case string: gen.Options["stop"] = []string{v}
|
||||
case []string: gen.Options["stop"] = v
|
||||
case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
|
||||
}
|
||||
}
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
opts := map[string]any{}
|
||||
if r.Temperature != nil { opts["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { opts["top_p"] = r.TopP }
|
||||
if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { opts["seed"] = int(r.Seed) }
|
||||
if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
|
||||
input := r.ParseInput()
|
||||
if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
|
||||
return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
|
||||
}
|
||||
|
||||
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
var oResp OllamaEmbeddingResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
data = append(data, dto.OpenAIEmbeddingResponseItem{
|
||||
Embedding: flattenedEmbeddings,
|
||||
Object: "embedding",
|
||||
})
|
||||
usage := &dto.Usage{
|
||||
TotalTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
PromptTokens: info.PromptTokens,
|
||||
}
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: info.UpstreamModelName,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
|
||||
for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
|
||||
usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
|
||||
embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
|
||||
out, _ := common.Marshal(embResp)
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
||||
flattened := []float64{}
|
||||
for _, row := range embeddings {
|
||||
flattened = append(flattened, row...)
|
||||
}
|
||||
return flattened
|
||||
}
|
||||
|
||||
210
relay/channel/ollama/stream.go
Normal file
210
relay/channel/ollama/stream.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ollamaChatStreamChunk struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
// chat
|
||||
Message *struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Thinking json.RawMessage `json:"thinking"`
|
||||
ToolCalls []struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
// generate
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
LoadDuration int64 `json:"load_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration"`
|
||||
EvalDuration int64 `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func toUnix(ts string) int64 {
|
||||
if ts == "" { return time.Now().Unix() }
|
||||
// try time.RFC3339 or with nanoseconds
|
||||
t, err := time.Parse(time.RFC3339Nano, ts)
|
||||
if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
usage := &dto.Usage{}
|
||||
var model = info.UpstreamModelName
|
||||
var responseId = common.GetUUID()
|
||||
var created = time.Now().Unix()
|
||||
var toolCallIndex int
|
||||
start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
|
||||
if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" { continue }
|
||||
var chunk ollamaChatStreamChunk
|
||||
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
||||
logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
|
||||
return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if chunk.Model != "" { model = chunk.Model }
|
||||
created = toUnix(chunk.CreatedAt)
|
||||
|
||||
if !chunk.Done {
|
||||
// delta content
|
||||
var content string
|
||||
if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
|
||||
delta := dto.ChatCompletionsStreamResponse{
|
||||
Id: responseId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: created,
|
||||
Model: model,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{ {
|
||||
Index: 0,
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
|
||||
} },
|
||||
}
|
||||
if content != "" { delta.Choices[0].Delta.SetContentString(content) }
|
||||
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(chunk.Message.Thinking))
|
||||
if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
|
||||
}
|
||||
// tool calls
|
||||
if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
|
||||
delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
|
||||
for _, tc := range chunk.Message.ToolCalls {
|
||||
// arguments -> string
|
||||
argBytes, _ := json.Marshal(tc.Function.Arguments)
|
||||
toolId := fmt.Sprintf("call_%d", toolCallIndex)
|
||||
tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
|
||||
tr.SetIndex(toolCallIndex)
|
||||
toolCallIndex++
|
||||
delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
|
||||
}
|
||||
}
|
||||
if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
continue
|
||||
}
|
||||
// done frame
|
||||
// finalize once and break loop
|
||||
usage.PromptTokens = chunk.PromptEvalCount
|
||||
usage.CompletionTokens = chunk.EvalCount
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
finishReason := chunk.DoneReason
|
||||
if finishReason == "" { finishReason = "stop" }
|
||||
// emit stop delta
|
||||
if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
|
||||
if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
}
|
||||
// emit usage frame
|
||||
if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
|
||||
if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
}
|
||||
// send [DONE]
|
||||
helper.Done(c)
|
||||
break
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// non-stream handler for chat/generate
|
||||
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
raw := string(body)
|
||||
if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
|
||||
|
||||
lines := strings.Split(raw, "\n")
|
||||
var (
|
||||
aggContent strings.Builder
|
||||
reasoningBuilder strings.Builder
|
||||
lastChunk ollamaChatStreamChunk
|
||||
parsedAny bool
|
||||
)
|
||||
for _, ln := range lines {
|
||||
ln = strings.TrimSpace(ln)
|
||||
if ln == "" { continue }
|
||||
var ck ollamaChatStreamChunk
|
||||
if err := json.Unmarshal([]byte(ln), &ck); err != nil {
|
||||
if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
continue
|
||||
}
|
||||
parsedAny = true
|
||||
lastChunk = ck
|
||||
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(ck.Message.Thinking))
|
||||
if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
|
||||
}
|
||||
if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
|
||||
}
|
||||
|
||||
if !parsedAny {
|
||||
var single ollamaChatStreamChunk
|
||||
if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
lastChunk = single
|
||||
if single.Message != nil {
|
||||
if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
|
||||
aggContent.WriteString(single.Message.Content)
|
||||
} else { aggContent.WriteString(single.Response) }
|
||||
}
|
||||
|
||||
model := lastChunk.Model
|
||||
if model == "" { model = info.UpstreamModelName }
|
||||
created := toUnix(lastChunk.CreatedAt)
|
||||
usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
|
||||
content := aggContent.String()
|
||||
finishReason := lastChunk.DoneReason
|
||||
if finishReason == "" { finishReason = "stop" }
|
||||
|
||||
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
|
||||
if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
|
||||
full := dto.OpenAITextResponse{
|
||||
Id: common.GetUUID(),
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Created: created,
|
||||
Choices: []dto.OpenAITextResponseChoice{ {
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
} },
|
||||
Usage: *usage,
|
||||
}
|
||||
out, _ := common.Marshal(full)
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func contentPtr(s string) *string { if s=="" { return nil }; return &s }
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
if common.DebugEnabled {
|
||||
println("upstream response body:", string(responseBody))
|
||||
}
|
||||
// Unmarshal to simpleResponse
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
|
||||
// 尝试解析为 openrouter enterprise
|
||||
var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
|
||||
err = common.Unmarshal(responseBody, &enterpriseResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if enterpriseResponse.Success {
|
||||
responseBody = enterpriseResponse.Data
|
||||
} else {
|
||||
logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -33,6 +33,12 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
if responsesResponse.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", responsesResponse.GetQuality())
|
||||
c.Set("image_generation_call_size", responsesResponse.GetSize())
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
@@ -80,18 +86,25 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package openrouter
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type RequestReasoning struct {
|
||||
// One of the following (not both):
|
||||
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
|
||||
@@ -7,3 +9,8 @@ type RequestReasoning struct {
|
||||
// Optional: Default is false. All models support this.
|
||||
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
|
||||
}
|
||||
|
||||
type OpenRouterEnterpriseResponse struct {
|
||||
Data json.RawMessage `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Seed int64 `json:"seed"`
|
||||
AspectRatio string `json:"aspect_ratio"`
|
||||
Frames int `json:"frames,omitempty"`
|
||||
}
|
||||
|
||||
type responsePayload struct {
|
||||
@@ -93,6 +94,9 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
@@ -100,7 +104,12 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
} else {
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
@@ -160,6 +169,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
||||
if isNewAPIRelay(key) {
|
||||
uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
|
||||
}
|
||||
payload := map[string]string{
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
@@ -177,17 +189,20 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
if isNewAPIRelay(key) {
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
} else {
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
}
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
@@ -311,10 +326,15 @@ func hmacSHA256(key []byte, data []byte) []byte {
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: "jimeng_vgfm_i2v_l20",
|
||||
Prompt: req.Prompt,
|
||||
AspectRatio: "16:9", // Default aspect ratio
|
||||
Seed: -1, // Default to random
|
||||
ReqKey: req.Model,
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
|
||||
switch req.Duration {
|
||||
case 10:
|
||||
r.Frames = 241 // 24*10+1 = 241
|
||||
default:
|
||||
r.Frames = 121 // 24*5+1 = 121
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
@@ -334,6 +354,22 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
// 即梦视频3.0 ReqKey转换
|
||||
// https://www.volcengine.com/docs/85621/1792707
|
||||
if strings.Contains(r.ReqKey, "jimeng_v30") {
|
||||
if len(r.ImageUrls) > 1 {
|
||||
// 多张图片:首尾帧生成
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_tail_v30", 1)
|
||||
} else if len(r.ImageUrls) == 1 {
|
||||
// 单张图片:图生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_i2v_first_v30", 1)
|
||||
} else {
|
||||
// 无图片:文生视频
|
||||
r.ReqKey = strings.Replace(r.ReqKey, "jimeng_v30", "jimeng_t2v_v30", 1)
|
||||
}
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
@@ -362,3 +398,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
taskResult.Url = resTask.Data.VideoUrl
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
@@ -117,6 +117,11 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
@@ -199,6 +204,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
}
|
||||
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||
if isNewAPIRelay(key) {
|
||||
url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@@ -304,8 +312,13 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
//}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
|
||||
if isNewAPIRelay(apiKey) {
|
||||
return apiKey, nil // new api relay
|
||||
}
|
||||
keyParts := strings.Split(apiKey, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return "", errors.New("invalid api_key, required format is accessKey|secretKey")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
if len(keyParts) == 1 {
|
||||
return accessKey, nil
|
||||
@@ -352,3 +365,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
@@ -80,8 +80,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
@@ -112,6 +111,10 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
path = "/img2video"
|
||||
case constant.TaskActionFirstTailGenerate:
|
||||
path = "/start-end2video"
|
||||
case constant.TaskActionReferenceGenerate:
|
||||
path = "/reference2video"
|
||||
default:
|
||||
path = "/text2video"
|
||||
}
|
||||
@@ -187,14 +190,9 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
}
|
||||
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Images: images,
|
||||
Images: req.Images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
channelconstant "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
@@ -41,6 +42,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return request, nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
@@ -186,20 +189,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// 支持自定义域名,如果未设置则使用默认域名
|
||||
baseUrl := info.ChannelBaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
|
||||
@@ -8,6 +8,12 @@ var ModelList = []string{
|
||||
"Doubao-lite-32k",
|
||||
"Doubao-lite-4k",
|
||||
"Doubao-embedding",
|
||||
"doubao-seedream-4-0-250828",
|
||||
"seedream-4-0-250828",
|
||||
"doubao-seedance-1-0-pro-250528",
|
||||
"seedance-1-0-pro-250528",
|
||||
"doubao-seed-1-6-thinking-250715",
|
||||
"seed-1-6-thinking-250715",
|
||||
}
|
||||
|
||||
var ChannelName = "volcengine"
|
||||
|
||||
@@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
@@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -69,6 +70,31 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
if request.System == nil {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
if request.IsStringSystem() {
|
||||
existing := strings.TrimSpace(request.GetStringSystem())
|
||||
if existing == "" {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
|
||||
} else {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
|
||||
}
|
||||
} else {
|
||||
systemContents := request.ParseSystem()
|
||||
newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
|
||||
newSystem.SetText(info.ChannelSetting.SystemPrompt)
|
||||
if len(systemContents) == 0 {
|
||||
request.System = []dto.ClaudeMediaMessage{newSystem}
|
||||
} else {
|
||||
request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
|
||||
@@ -79,34 +79,18 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
|
||||
req.Images = []string{req.Image}
|
||||
}
|
||||
|
||||
if req.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
if info.ChannelType == constant.ChannelTypeVidu {
|
||||
// vidu 增加 首尾帧生视频和参考图生视频
|
||||
if len(req.Images) == 2 {
|
||||
action = constant.TaskActionFirstTailGenerate
|
||||
} else if len(req.Images) > 2 {
|
||||
action = constant.TaskActionReferenceGenerate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
|
||||
hasPrompt, ok := requestObj.(HasPrompt)
|
||||
if !ok {
|
||||
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
action := constant.TaskActionTextGenerate
|
||||
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, requestObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
|
||||
}
|
||||
|
||||
return ValidateTaskRequestWithImage(c, info, req)
|
||||
}
|
||||
|
||||
@@ -90,41 +90,43 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
// 如果有系统提示,则将其添加到请求中
|
||||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
if ok {
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,6 +278,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
var imageGenerationCallPrice float64
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
|
||||
@@ -331,6 +340,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -429,6 +440,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dImageGenerationCallQuota.IsZero() {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = imageGenerationCallPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/gemini"
|
||||
@@ -94,6 +95,32 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
if request.SystemInstructions == nil {
|
||||
request.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: info.ChannelSetting.SystemPrompt},
|
||||
},
|
||||
}
|
||||
} else if len(request.SystemInstructions.Parts) == 0 {
|
||||
request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
merged := false
|
||||
for i := range request.SystemInstructions.Parts {
|
||||
if request.SystemInstructions.Parts[i].Text == "" {
|
||||
continue
|
||||
}
|
||||
request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
|
||||
merged = true
|
||||
break
|
||||
}
|
||||
if !merged {
|
||||
request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up empty system instruction
|
||||
if request.SystemInstructions != nil {
|
||||
hasContent := false
|
||||
|
||||
@@ -52,6 +52,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
var cacheRatio float64
|
||||
var imageRatio float64
|
||||
var cacheCreationRatio float64
|
||||
var audioRatio float64
|
||||
var audioCompletionRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
|
||||
if meta.MaxTokens != 0 {
|
||||
@@ -73,6 +75,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
|
||||
audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
|
||||
audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
@@ -90,6 +94,8 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
AudioRatio: audioRatio,
|
||||
AudioCompletionRatio: audioCompletionRatio,
|
||||
CacheCreationRatio: cacheCreationRatio,
|
||||
ShouldPreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
|
||||
@@ -21,7 +21,11 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
request, err = GetAndValidateGeminiEmbeddingRequest(c)
|
||||
} else {
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
}
|
||||
case types.RelayFormatClaude:
|
||||
request, err = GetAndValidateClaudeRequest(c)
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
@@ -288,7 +292,6 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
@@ -304,3 +307,12 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
|
||||
request := &dto.GeminiEmbeddingRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user