mirror of https://github.com/QuantumNous/new-api.git
feat: refactor the Ollama channel request handling
relay/channel/ollama/adaptor.go
@@ -10,6 +10,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/types"
+	"strings"

 	"github.com/gin-gonic/gin"
 )
@@ -48,15 +49,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }

 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	if info.RelayFormat == types.RelayFormatClaude {
-		return info.ChannelBaseUrl + "/v1/chat/completions", nil
-	}
-	switch info.RelayMode {
-	case relayconstant.RelayModeEmbeddings:
+	// embeddings fixed endpoint
+	if info.RelayMode == relayconstant.RelayModeEmbeddings {
 		return info.ChannelBaseUrl + "/api/embed", nil
-	default:
-		return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
 	}
+	// For chat vs generate: if the original path contains "/v1/completions", map to generate; otherwise chat
+	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+		return info.ChannelBaseUrl + "/api/generate", nil
+	}
+	return info.ChannelBaseUrl + "/api/chat", nil
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
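Editor's note: the hunk above collapses the old Claude-format/switch routing into three fixed Ollama endpoints. A standalone sketch of the resulting mapping, not part of the commit (pickOllamaEndpoint is a made-up name; the base URL is Ollama's default):

package main

import (
	"fmt"
	"strings"
)

// pickOllamaEndpoint mirrors the decision order above: embeddings first,
// then /v1/completions -> /api/generate, everything else -> /api/chat.
func pickOllamaEndpoint(base, path string, isEmbedding bool) string {
	if isEmbedding {
		return base + "/api/embed"
	}
	if strings.Contains(path, "/v1/completions") {
		return base + "/api/generate"
	}
	return base + "/api/chat"
}

func main() {
	base := "http://localhost:11434" // Ollama's default address
	fmt.Println(pickOllamaEndpoint(base, "/v1/embeddings", true))        // .../api/embed
	fmt.Println(pickOllamaEndpoint(base, "/v1/completions", false))      // .../api/generate
	fmt.Println(pickOllamaEndpoint(base, "/v1/chat/completions", false)) // .../api/chat
}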
@@ -66,10 +67,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 }

 func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	return requestOpenAI2Ollama(c, request)
+	// decide generate or chat
+	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+		return openAIToGenerate(c, request)
+	}
+	return openAIChatToOllamaChat(c, request)
 }

 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -92,15 +95,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
-		usage, err = ollamaEmbeddingHandler(c, info, resp)
+		return ollamaEmbeddingHandler(c, info, resp)
 	default:
 		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
-		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
+			return ollamaStreamHandler(c, info, resp)
 		}
+		return ollamaChatHandler(c, info, resp)
 	}
-	return
 }

 func (a *Adaptor) GetModelList() []string {
relay/channel/ollama/dto.go
@@ -5,45 +5,70 @@ import (
 	"one-api/dto"
 )

-type OllamaRequest struct {
-	Model            string                `json:"model,omitempty"`
-	Messages         []dto.Message         `json:"messages,omitempty"`
-	Stream           bool                  `json:"stream,omitempty"`
-	Temperature      *float64              `json:"temperature,omitempty"`
-	Seed             float64               `json:"seed,omitempty"`
-	Topp             float64               `json:"top_p,omitempty"`
-	TopK             int                   `json:"top_k,omitempty"`
-	Stop             any                   `json:"stop,omitempty"`
-	MaxTokens        uint                  `json:"max_tokens,omitempty"`
-	Tools            []dto.ToolCallRequest `json:"tools,omitempty"`
-	ResponseFormat   any                   `json:"response_format,omitempty"`
-	FrequencyPenalty float64               `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64               `json:"presence_penalty,omitempty"`
-	Suffix           any                   `json:"suffix,omitempty"`
-	StreamOptions    *dto.StreamOptions    `json:"stream_options,omitempty"`
-	Prompt           any                   `json:"prompt,omitempty"`
-	Think            json.RawMessage       `json:"think,omitempty"`
+// OllamaChatMessage represents a single chat message
+type OllamaChatMessage struct {
+	Role      string           `json:"role"`
+	Content   string           `json:"content,omitempty"`
+	Images    []string         `json:"images,omitempty"`
+	ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
+	ToolName  string           `json:"tool_name,omitempty"`
+	Thinking  json.RawMessage  `json:"thinking,omitempty"`
 }

-type Options struct {
-	Seed             int      `json:"seed,omitempty"`
-	Temperature      *float64 `json:"temperature,omitempty"`
-	TopK             int      `json:"top_k,omitempty"`
-	TopP             float64  `json:"top_p,omitempty"`
-	FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
-	NumPredict       int      `json:"num_predict,omitempty"`
-	NumCtx           int      `json:"num_ctx,omitempty"`
+type OllamaToolFunction struct {
+	Name        string      `json:"name"`
+	Description string      `json:"description,omitempty"`
+	Parameters  interface{} `json:"parameters,omitempty"`
+}
+
+type OllamaTool struct {
+	Type     string             `json:"type"`
+	Function OllamaToolFunction `json:"function"`
+}
+
+type OllamaToolCall struct {
+	Function struct {
+		Name      string      `json:"name"`
+		Arguments interface{} `json:"arguments"`
+	} `json:"function"`
+}
+
+// OllamaChatRequest -> /api/chat
+type OllamaChatRequest struct {
+	Model     string              `json:"model"`
+	Messages  []OllamaChatMessage `json:"messages"`
+	Tools     interface{}         `json:"tools,omitempty"`
+	Format    interface{}         `json:"format,omitempty"`
+	Stream    bool                `json:"stream,omitempty"`
+	Options   map[string]any      `json:"options,omitempty"`
+	KeepAlive interface{}         `json:"keep_alive,omitempty"`
+	Think     json.RawMessage     `json:"think,omitempty"`
+}
+
+// OllamaGenerateRequest -> /api/generate
+type OllamaGenerateRequest struct {
+	Model     string          `json:"model"`
+	Prompt    string          `json:"prompt,omitempty"`
+	Suffix    string          `json:"suffix,omitempty"`
+	Images    []string        `json:"images,omitempty"`
+	Format    interface{}     `json:"format,omitempty"`
+	Stream    bool            `json:"stream,omitempty"`
+	Options   map[string]any  `json:"options,omitempty"`
+	KeepAlive interface{}     `json:"keep_alive,omitempty"`
+	Think     json.RawMessage `json:"think,omitempty"`
 }

 type OllamaEmbeddingRequest struct {
-	Model   string   `json:"model,omitempty"`
-	Input   []string `json:"input"`
-	Options *Options `json:"options,omitempty"`
+	Model      string         `json:"model"`
+	Input      interface{}    `json:"input"`
+	Options    map[string]any `json:"options,omitempty"`
+	Dimensions int            `json:"dimensions,omitempty"`
 }

 type OllamaEmbeddingResponse struct {
 	Error string `json:"error,omitempty"`
 	Model string `json:"model"`
-	Embedding [][]float64 `json:"embeddings,omitempty"`
+	Embeddings      [][]float64 `json:"embeddings"`
+	PromptEvalCount int         `json:"prompt_eval_count,omitempty"`
 }
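Editor's note: for reference, a runnable sketch of the wire format the new /api/chat structs marshal to, using a trimmed standalone copy of the fields above (struct names here are illustrative, not the commit's):

package main

import (
	"encoding/json"
	"fmt"
)

type chatMessage struct {
	Role    string   `json:"role"`
	Content string   `json:"content,omitempty"`
	Images  []string `json:"images,omitempty"`
}

type chatRequest struct {
	Model    string         `json:"model"`
	Messages []chatMessage  `json:"messages"`
	Stream   bool           `json:"stream,omitempty"`
	Options  map[string]any `json:"options,omitempty"`
}

func main() {
	r := chatRequest{
		Model:    "llama3.2",
		Messages: []chatMessage{{Role: "user", Content: "hello"}},
		Options:  map[string]any{"temperature": 0.7, "num_predict": 128},
	}
	b, _ := json.Marshal(r)
	fmt.Println(string(b))
	// {"model":"llama3.2","messages":[{"role":"user","content":"hello"}],"options":{"num_predict":128,"temperature":0.7}}
}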
relay/channel/ollama/relay-ollama.go
@@ -1,6 +1,7 @@
 package ollama

 import (
+	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,121 +15,179 @@ import (
 	"github.com/gin-gonic/gin"
 )

-func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
-	messages := make([]dto.Message, 0, len(request.Messages))
-	for _, message := range request.Messages {
-		if !message.IsStringContent() {
-			mediaMessages := message.ParseContent()
-			for j, mediaMessage := range mediaMessages {
-				if mediaMessage.Type == dto.ContentTypeImageURL {
-					imageUrl := mediaMessage.GetImageMedia()
-					// check if not base64
-					if strings.HasPrefix(imageUrl.Url, "http") {
-						fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
-						if err != nil {
-							return nil, err
-						}
-						imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
-					}
-					mediaMessage.ImageUrl = imageUrl
-					mediaMessages[j] = mediaMessage
-				}
-			}
-			message.SetMediaContent(mediaMessages)
-		}
-		messages = append(messages, dto.Message{
-			Role:       message.Role,
-			Content:    message.Content,
-			ToolCalls:  message.ToolCalls,
-			ToolCallId: message.ToolCallId,
-		})
-	}
-	str, ok := request.Stop.(string)
-	var Stop []string
-	if ok {
-		Stop = []string{str}
-	} else {
-		Stop, _ = request.Stop.([]string)
-	}
-	ollamaRequest := &OllamaRequest{
-		Model:            request.Model,
-		Messages:         messages,
-		Stream:           request.Stream,
-		Temperature:      request.Temperature,
-		Seed:             request.Seed,
-		Topp:             request.TopP,
-		TopK:             request.TopK,
-		Stop:             Stop,
-		Tools:            request.Tools,
-		MaxTokens:        request.GetMaxTokens(),
-		ResponseFormat:   request.ResponseFormat,
-		FrequencyPenalty: request.FrequencyPenalty,
-		PresencePenalty:  request.PresencePenalty,
-		Prompt:           request.Prompt,
-		StreamOptions:    request.StreamOptions,
-		Suffix:           request.Suffix,
-	}
-	ollamaRequest.Think = request.Think
-	return ollamaRequest, nil
-}
+// openAIChatToOllamaChat converts an OpenAI-style chat request to an Ollama chat request
+func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
+	chatReq := &OllamaChatRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	// format mapping
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" {
+			chatReq.Format = "json"
+		} else if r.ResponseFormat.Type == "json_schema" {
+			// supply the schema object directly
+			if len(r.ResponseFormat.JsonSchema) > 0 {
+				var schema any
+				_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+				chatReq.Format = schema
+			}
+		}
+	}
+
+	// options mapping
+	if r.Temperature != nil {
+		chatReq.Options["temperature"] = r.Temperature
+	}
+	if r.TopP != 0 {
+		chatReq.Options["top_p"] = r.TopP
+	}
+	if r.TopK != 0 {
+		chatReq.Options["top_k"] = r.TopK
+	}
+	if r.FrequencyPenalty != 0 {
+		chatReq.Options["frequency_penalty"] = r.FrequencyPenalty
+	}
+	if r.PresencePenalty != 0 {
+		chatReq.Options["presence_penalty"] = r.PresencePenalty
+	}
+	if r.Seed != 0 {
+		chatReq.Options["seed"] = int(r.Seed)
+	}
+	if mt := r.GetMaxTokens(); mt != 0 {
+		chatReq.Options["num_predict"] = int(mt)
+	}
+
+	// Stop -> options.stop (array)
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string:
+			chatReq.Options["stop"] = []string{v}
+		case []string:
+			chatReq.Options["stop"] = v
+		case []any:
+			arr := make([]string, 0, len(v))
+			for _, i := range v {
+				if s, ok := i.(string); ok {
+					arr = append(arr, s)
+				}
+			}
+			if len(arr) > 0 {
+				chatReq.Options["stop"] = arr
+			}
+		}
+	}
+
+	// tools
+	if len(r.Tools) > 0 {
+		tools := make([]OllamaTool, 0, len(r.Tools))
+		for _, t := range r.Tools {
+			tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
+		}
+		chatReq.Tools = tools
+	}
+
+	// messages
+	chatReq.Messages = make([]OllamaChatMessage, 0, len(r.Messages))
+	for _, m := range r.Messages {
+		// gather text parts & images
+		var textBuilder strings.Builder
+		var images []string
+		if m.IsStringContent() {
+			textBuilder.WriteString(m.StringContent())
+		} else {
+			parts := m.ParseContent()
+			for _, part := range parts {
+				if part.Type == dto.ContentTypeImageURL {
+					img := part.GetImageMedia()
+					if img != nil && img.Url != "" {
+						// ensure a base64 data URL
+						if strings.HasPrefix(img.Url, "http") {
+							fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
+							if err != nil {
+								return nil, err
+							}
+							img.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+						}
+						images = append(images, img.Url)
+					}
+				} else if part.Type == dto.ContentTypeText {
+					textBuilder.WriteString(part.Text)
+				}
+			}
+		}
+		cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
+		if len(images) > 0 {
+			cm.Images = images
+		}
+		// history tool-call result message
+		if m.Role == "tool" && m.Name != nil {
+			cm.ToolName = *m.Name
+		}
+		// tool calls from a previous assistant message
+		if len(m.ToolCalls) > 0 {
+			calls := make([]OllamaToolCall, 0, len(m.ToolCalls))
+			for _, tc := range m.ToolCalls {
+				var args interface{}
+				if tc.Function.Arguments != "" {
+					_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
+				}
+				oc := OllamaToolCall{}
+				oc.Function.Name = tc.Function.Name
+				if args == nil {
+					args = map[string]any{}
+				}
+				oc.Function.Arguments = args
+				calls = append(calls, oc)
+			}
+			cm.ToolCalls = calls
+		}
+		chatReq.Messages = append(chatReq.Messages, cm)
+	}
+	return chatReq, nil
+}
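Editor's note: both converters apply the same stop-sequence normalization, since OpenAI's `stop` can decode as a string, a []string, or a []any. A standalone sketch of that rule (normalizeStop is a made-up helper, not part of the commit):

package main

import "fmt"

// normalizeStop coerces the possible decoded shapes of OpenAI's `stop`
// into the string array Ollama expects under options.stop.
func normalizeStop(stop any) []string {
	switch v := stop.(type) {
	case string:
		return []string{v}
	case []string:
		return v
	case []any:
		arr := make([]string, 0, len(v))
		for _, i := range v {
			if s, ok := i.(string); ok {
				arr = append(arr, s)
			}
		}
		return arr
	}
	return nil
}

func main() {
	fmt.Println(normalizeStop("END"))              // [END]
	fmt.Println(normalizeStop([]any{"a", "b", 3})) // [a b] — non-strings are dropped
}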

-func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
-	return &OllamaEmbeddingRequest{
-		Model: request.Model,
-		Input: request.ParseInput(),
-		Options: &Options{
-			Seed:             int(request.Seed),
-			Temperature:      request.Temperature,
-			TopP:             request.TopP,
-			FrequencyPenalty: request.FrequencyPenalty,
-			PresencePenalty:  request.PresencePenalty,
-		},
-	}
-}
+// openAIToGenerate converts an OpenAI completions request to an Ollama generate request
+func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
+	gen := &OllamaGenerateRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	// Prompt may be in r.Prompt (string or []any)
+	if r.Prompt != nil {
+		switch v := r.Prompt.(type) {
+		case string:
+			gen.Prompt = v
+		case []any:
+			var sb strings.Builder
+			for _, it := range v {
+				if s, ok := it.(string); ok {
+					sb.WriteString(s)
+				}
+			}
+			gen.Prompt = sb.String()
+		default:
+			gen.Prompt = fmt.Sprintf("%v", r.Prompt)
+		}
+	}
+	if r.Suffix != nil {
+		if s, ok := r.Suffix.(string); ok {
+			gen.Suffix = s
+		}
+	}
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" {
+			gen.Format = "json"
+		} else if r.ResponseFormat.Type == "json_schema" {
+			var schema any
+			_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+			gen.Format = schema
+		}
+	}
+	if r.Temperature != nil {
+		gen.Options["temperature"] = r.Temperature
+	}
+	if r.TopP != 0 {
+		gen.Options["top_p"] = r.TopP
+	}
+	if r.TopK != 0 {
+		gen.Options["top_k"] = r.TopK
+	}
+	if r.FrequencyPenalty != 0 {
+		gen.Options["frequency_penalty"] = r.FrequencyPenalty
+	}
+	if r.PresencePenalty != 0 {
+		gen.Options["presence_penalty"] = r.PresencePenalty
+	}
+	if r.Seed != 0 {
+		gen.Options["seed"] = int(r.Seed)
+	}
+	if mt := r.GetMaxTokens(); mt != 0 {
+		gen.Options["num_predict"] = int(mt)
+	}
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string:
+			gen.Options["stop"] = []string{v}
+		case []string:
+			gen.Options["stop"] = v
+		case []any:
+			arr := make([]string, 0, len(v))
+			for _, i := range v {
+				if s, ok := i.(string); ok {
+					arr = append(arr, s)
+				}
+			}
+			if len(arr) > 0 {
+				gen.Options["stop"] = arr
+			}
+		}
+	}
+	return gen, nil
+}
+
+func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+	opts := map[string]any{}
+	if r.Temperature != nil {
+		opts["temperature"] = r.Temperature
+	}
+	if r.TopP != 0 {
+		opts["top_p"] = r.TopP
+	}
+	if r.TopK != 0 {
+		opts["top_k"] = r.TopK
+	}
+	if r.FrequencyPenalty != 0 {
+		opts["frequency_penalty"] = r.FrequencyPenalty
+	}
+	if r.PresencePenalty != 0 {
+		opts["presence_penalty"] = r.PresencePenalty
+	}
+	if r.Seed != 0 {
+		opts["seed"] = int(r.Seed)
+	}
+	if r.Dimensions != 0 {
+		opts["dimensions"] = r.Dimensions
+	}
+	input := r.ParseInput()
+	if len(input) == 1 {
+		return &OllamaEmbeddingRequest{Model: r.Model, Input: input[0], Options: opts, Dimensions: r.Dimensions}
+	}
+	return &OllamaEmbeddingRequest{Model: r.Model, Input: input, Options: opts, Dimensions: r.Dimensions}
+}
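Editor's note: a sketch of the /api/embed payload the new converter emits; a single input is sent as a bare string, multiple inputs as an array (model and input values are illustrative):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// illustrative payload only — mirrors what requestOpenAI2Embeddings builds
	payload := map[string]any{
		"model": "nomic-embed-text",
		"input": "hello world", // single input stays a string
		// "input": []string{"a", "b"}, // multiple inputs stay an array
	}
	b, _ := json.Marshal(payload)
	fmt.Println(string(b)) // {"input":"hello world","model":"nomic-embed-text"}
}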
 func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
-	var ollamaEmbeddingResponse OllamaEmbeddingResponse
-	responseBody, err := io.ReadAll(resp.Body)
+	var oResp OllamaEmbeddingResponse
+	body, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
 	service.CloseResponseBodyGracefully(resp)
-	err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
-	if err != nil {
+	if err = common.Unmarshal(body, &oResp); err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
-	if ollamaEmbeddingResponse.Error != "" {
-		return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	if oResp.Error != "" {
+		return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
-	flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
-	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
-	data = append(data, dto.OpenAIEmbeddingResponseItem{
-		Embedding: flattenedEmbeddings,
-		Object:    "embedding",
-	})
-	usage := &dto.Usage{
-		TotalTokens:      info.PromptTokens,
-		CompletionTokens: 0,
-		PromptTokens:     info.PromptTokens,
-	}
-	embeddingResponse := &dto.OpenAIEmbeddingResponse{
-		Object: "list",
-		Data:   data,
-		Model:  info.UpstreamModelName,
-		Usage:  *usage,
-	}
-	doResponseBody, err := common.Marshal(embeddingResponse)
-	if err != nil {
-		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
-	service.IOCopyBytesGracefully(c, resp, doResponseBody)
+	data := make([]dto.OpenAIEmbeddingResponseItem, 0, len(oResp.Embeddings))
+	for i, emb := range oResp.Embeddings {
+		data = append(data, dto.OpenAIEmbeddingResponseItem{Index: i, Object: "embedding", Embedding: emb})
+	}
+	usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens: 0, TotalTokens: oResp.PromptEvalCount}
+	embResp := &dto.OpenAIEmbeddingResponse{Object: "list", Data: data, Model: info.UpstreamModelName, Usage: *usage}
+	out, _ := common.Marshal(embResp)
+	service.IOCopyBytesGracefully(c, resp, out)
 	return usage, nil
 }
-
-func flattenEmbeddings(embeddings [][]float64) []float64 {
-	flattened := []float64{}
-	for _, row := range embeddings {
-		flattened = append(flattened, row...)
-	}
-	return flattened
-}
relay/channel/ollama/stream.go (new file, 165 lines; all lines added)
@@ -0,0 +1,165 @@
package ollama

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"one-api/common"
	"one-api/dto"
	"one-api/logger"
	relaycommon "one-api/relay/common"
	"one-api/relay/helper"
	"one-api/service"
	"one-api/types"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)

// Ollama streaming chunk (chat or generate)
type ollamaChatStreamChunk struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	// chat
	Message *struct {
		Role      string `json:"role"`
		Content   string `json:"content"`
		ToolCalls []struct {
			Function struct {
				Name      string      `json:"name"`
				Arguments interface{} `json:"arguments"`
			} `json:"function"`
		} `json:"tool_calls"`
	} `json:"message"`
	// generate
	Response        string `json:"response"`
	Done            bool   `json:"done"`
	DoneReason      string `json:"done_reason"`
	TotalDuration   int64  `json:"total_duration"`
	LoadDuration    int64  `json:"load_duration"`
	PromptEvalCount int    `json:"prompt_eval_count"`
	EvalCount       int    `json:"eval_count"`
	// generate mode may use these
	PromptEvalDuration int64 `json:"prompt_eval_duration"`
	EvalDuration       int64 `json:"eval_duration"`
}

// toUnix parses an RFC3339 / RFC3339Nano timestamp; falls back to time.Now
func toUnix(ts string) int64 {
	if ts == "" {
		return time.Now().Unix()
	}
	t, err := time.Parse(time.RFC3339Nano, ts)
	if err != nil {
		if t2, err2 := time.Parse(time.RFC3339, ts); err2 == nil {
			return t2.Unix()
		}
		return time.Now().Unix()
	}
	return t.Unix()
}
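Editor's note: each line the scanner reads below is one NDJSON frame from Ollama. A standalone sketch decoding a sample frame with a trimmed copy of the chunk struct above (frame contents illustrative, struct name made up):

package main

import (
	"encoding/json"
	"fmt"
)

type chunk struct {
	Model   string `json:"model"`
	Message *struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
	Done            bool   `json:"done"`
	DoneReason      string `json:"done_reason"`
	PromptEvalCount int    `json:"prompt_eval_count"`
	EvalCount       int    `json:"eval_count"`
}

func main() {
	// one /api/chat NDJSON frame, shaped like Ollama's documented output
	line := `{"model":"llama3.2","message":{"role":"assistant","content":"Hi"},"done":false}`
	var ch chunk
	if err := json.Unmarshal([]byte(line), &ch); err != nil {
		panic(err)
	}
	fmt.Println(ch.Message.Content, ch.Done) // Hi false
}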
// streaming handler: convert the Ollama stream into OpenAI SSE
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest)
	}
	defer service.CloseResponseBodyGracefully(resp)

	helper.SetEventStreamHeaders(c)
	scanner := bufio.NewScanner(resp.Body)
	usage := &dto.Usage{}
	var model = info.UpstreamModelName
	var responseId = common.GetUUID()
	var created = time.Now().Unix()
	var aggregatedText strings.Builder
	var toolCallIndex int
	// send start event
	start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
	if data, err := common.Marshal(start); err == nil {
		_ = helper.StringData(c, string(data))
	}

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		var chunk ollamaChatStreamChunk
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
			return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
		}
		if chunk.Model != "" {
			model = chunk.Model
		}
		created = toUnix(chunk.CreatedAt)

		if !chunk.Done {
			// delta content: chat mode carries it in message.content, generate mode in response
			var content string
			if chunk.Message != nil {
				content = chunk.Message.Content
			} else {
				content = chunk.Response
			}
			if content != "" {
				aggregatedText.WriteString(content)
			}
			delta := dto.ChatCompletionsStreamResponse{
				Id:      responseId,
				Object:  "chat.completion.chunk",
				Created: created,
				Model:   model,
				Choices: []dto.ChatCompletionsStreamResponseChoice{{
					Index: 0,
					Delta: dto.ChatCompletionsStreamResponseChoiceDelta{Role: "assistant"},
				}},
			}
			if content != "" {
				delta.Choices[0].Delta.SetContentString(content)
			}
			// tool calls
			if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
				delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 0, len(chunk.Message.ToolCalls))
				for _, tc := range chunk.Message.ToolCalls {
					// arguments -> string
					argBytes, _ := json.Marshal(tc.Function.Arguments)
					tr := dto.ToolCallResponse{ID: "", Type: nil, Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
					tr.SetIndex(toolCallIndex)
					toolCallIndex++
					delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
				}
			}
			if data, err := common.Marshal(delta); err == nil {
				_ = helper.StringData(c, string(data))
			}
			continue
		}

		// done frame
		usage.PromptTokens = chunk.PromptEvalCount
		usage.CompletionTokens = chunk.EvalCount
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
		finishReason := chunk.DoneReason
		if finishReason == "" {
			finishReason = "stop"
		}
		stop := helper.GenerateStopResponse(responseId, created, model, finishReason)
		if data, err := common.Marshal(stop); err == nil {
			_ = helper.StringData(c, string(data))
		}
		final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage)
		if data, err := common.Marshal(final); err == nil {
			_ = helper.StringData(c, string(data))
		}
	}
	if err := scanner.Err(); err != nil && err != io.EOF {
		logger.LogError(c, "ollama stream scan error: "+err.Error())
	}
	return usage, nil
}
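Editor's note: schematically, the handler above emits an OpenAI-style SSE sequence along these lines, one chunk per Ollama frame (assuming helper.StringData applies the usual "data:" framing; ids, timestamps, and exact frame shapes come from the helper functions and are approximate):

data: {"id":"<uuid>","object":"chat.completion.chunk","created":1714534906,"model":"llama3.2","choices":[{"index":0,"delta":{"role":"assistant","content":"Hi"}}]}
data: {"id":"<uuid>","object":"chat.completion.chunk","created":1714534906,"model":"llama3.2","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
data: {"id":"<uuid>","object":"chat.completion.chunk","created":1714534906,"model":"llama3.2","choices":[],"usage":{"prompt_tokens":10,"completion_tokens":2,"total_tokens":12}}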
// non-stream handler for chat/generate
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}
	service.CloseResponseBodyGracefully(resp)
	if common.DebugEnabled {
		println("ollama non-stream resp:", string(body))
	}
	var chunk ollamaChatStreamChunk
	if err = json.Unmarshal(body, &chunk); err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	model := chunk.Model
	if model == "" {
		model = info.UpstreamModelName
	}
	created := toUnix(chunk.CreatedAt)
	content := ""
	if chunk.Message != nil {
		content = chunk.Message.Content
	} else {
		content = chunk.Response
	}
	usage := &dto.Usage{PromptTokens: chunk.PromptEvalCount, CompletionTokens: chunk.EvalCount, TotalTokens: chunk.PromptEvalCount + chunk.EvalCount}
	// Build an OpenAI-style response
	full := dto.OpenAITextResponse{
		Id:      common.GetUUID(),
		Model:   model,
		Object:  "chat.completion",
		Created: created,
		Choices: []dto.OpenAITextResponseChoice{{
			Index:   0,
			Message: dto.Message{Role: "assistant", Content: contentPtr(content)},
			FinishReason: func() string {
				if chunk.DoneReason == "" {
					return "stop"
				}
				return chunk.DoneReason
			}(),
		}},
		Usage: *usage,
	}
	out, _ := common.Marshal(full)
	service.IOCopyBytesGracefully(c, resp, out)
	return usage, nil
}

func contentPtr(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}
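Editor's note: for the non-stream path, an Ollama /api/chat response like the first sample below is rewritten into the OpenAI chat.completion shape of the second (values illustrative; field sets abridged):

Ollama response:
  {"model":"llama3.2","created_at":"2024-05-01T03:41:46Z","message":{"role":"assistant","content":"Hello!"},"done":true,"done_reason":"stop","prompt_eval_count":10,"eval_count":3}

Converted body:
  {"id":"<uuid>","object":"chat.completion","created":1714534906,"model":"llama3.2","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":3,"total_tokens":13}}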