diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index d6b5b697e..3732be91b 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -10,6 +10,7 @@ import (
 	relaycommon "one-api/relay/common"
 	relayconstant "one-api/relay/constant"
 	"one-api/types"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -48,15 +49,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	if info.RelayFormat == types.RelayFormatClaude {
-		return info.ChannelBaseUrl + "/v1/chat/completions", nil
-	}
-	switch info.RelayMode {
-	case relayconstant.RelayModeEmbeddings:
+	// embeddings use a fixed endpoint
+	if info.RelayMode == relayconstant.RelayModeEmbeddings {
 		return info.ChannelBaseUrl + "/api/embed", nil
-	default:
-		return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
 	}
+	// chat vs generate: if the original path contains "/v1/completions", map to /api/generate; otherwise /api/chat
+	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+		return info.ChannelBaseUrl + "/api/generate", nil
+	}
+	return info.ChannelBaseUrl + "/api/chat", nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -66,10 +67,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 }
 
 func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
 	if request == nil {
 		return nil, errors.New("request is nil")
+	}
+	// decide between generate and chat
+	if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
+		return openAIToGenerate(c, request)
 	}
-	return requestOpenAI2Ollama(c, request)
+	return openAIChatToOllamaChat(c, request)
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -92,15 +95,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	switch info.RelayMode {
 	case relayconstant.RelayModeEmbeddings:
-		usage, err = ollamaEmbeddingHandler(c, info, resp)
+		return ollamaEmbeddingHandler(c, info, resp)
 	default:
 		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
-		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
+			return ollamaStreamHandler(c, info, resp)
 		}
+		return ollamaChatHandler(c, info, resp)
 	}
-	return
 }
 
 func (a *Adaptor) GetModelList() []string {
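Note on the routing above: the mapping is easy to sanity-check in isolation. Below is a minimal, self-contained sketch (not part of the patch; `relayInfo` is a hypothetical stand-in for the handful of `relaycommon.RelayInfo` fields the router actually reads):

```go
package main

import (
	"fmt"
	"strings"
)

// relayInfo is a stand-in for relaycommon.RelayInfo, reduced to the fields
// GetRequestURL consults in the diff above.
type relayInfo struct {
	ChannelBaseUrl string
	RequestURLPath string
	IsEmbeddings   bool // stands in for RelayMode == RelayModeEmbeddings
}

// ollamaURL mirrors the routing in GetRequestURL: embeddings -> /api/embed,
// legacy completions -> /api/generate, everything else -> /api/chat.
func ollamaURL(info relayInfo) string {
	if info.IsEmbeddings {
		return info.ChannelBaseUrl + "/api/embed"
	}
	if strings.Contains(info.RequestURLPath, "/v1/completions") {
		return info.ChannelBaseUrl + "/api/generate"
	}
	return info.ChannelBaseUrl + "/api/chat"
}

func main() {
	base := "http://localhost:11434"
	fmt.Println(ollamaURL(relayInfo{base, "/v1/chat/completions", false})) // .../api/chat
	fmt.Println(ollamaURL(relayInfo{base, "/v1/completions", false}))      // .../api/generate
	fmt.Println(ollamaURL(relayInfo{base, "/v1/embeddings", true}))        // .../api/embed
}
```

Note that `/v1/chat/completions` does not contain the substring `/v1/completions`, so chat traffic is not misrouted to `/api/generate`.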
`json:"response_format,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Suffix any `json:"suffix,omitempty"` - StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` - Prompt any `json:"prompt,omitempty"` - Think json.RawMessage `json:"think,omitempty"` +// OllamaChatMessage represents a single chat message +type OllamaChatMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` + ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"` + ToolName string `json:"tool_name,omitempty"` + Thinking json.RawMessage `json:"thinking,omitempty"` } -type Options struct { - Seed int `json:"seed,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` +type OllamaToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +type OllamaTool struct { + Type string `json:"type"` + Function OllamaToolFunction `json:"function"` +} + +type OllamaToolCall struct { + Function struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + } `json:"function"` +} + +// OllamaChatRequest -> /api/chat +type OllamaChatRequest struct { + Model string `json:"model"` + Messages []OllamaChatMessage `json:"messages"` + Tools interface{} `json:"tools,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` +} + +// OllamaGenerateRequest -> /api/generate +type OllamaGenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + Suffix string `json:"suffix,omitempty"` + Images []string `json:"images,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` } type OllamaEmbeddingRequest struct { - Model string `json:"model,omitempty"` - Input []string `json:"input"` - Options *Options `json:"options,omitempty"` + Model string `json:"model"` + Input interface{} `json:"input"` + Options map[string]any `json:"options,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type OllamaEmbeddingResponse struct { - Error string `json:"error,omitempty"` - Model string `json:"model"` - Embedding [][]float64 `json:"embeddings,omitempty"` + Error string `json:"error,omitempty"` + Model string `json:"model"` + Embeddings [][]float64 `json:"embeddings"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` } + diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 27c67b4ec..897e22cbd 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -1,6 +1,7 @@ package ollama import ( + "encoding/json" "fmt" "io" "net/http" @@ -14,121 +15,179 @@ import ( "github.com/gin-gonic/gin" ) -func requestOpenAI2Ollama(c *gin.Context, request 
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 27c67b4ec..897e22cbd 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -1,6 +1,7 @@
 package ollama
 
 import (
+	"encoding/json"
 	"fmt"
 	"io"
 	"net/http"
@@ -14,121 +15,179 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
-	messages := make([]dto.Message, 0, len(request.Messages))
-	for _, message := range request.Messages {
-		if !message.IsStringContent() {
-			mediaMessages := message.ParseContent()
-			for j, mediaMessage := range mediaMessages {
-				if mediaMessage.Type == dto.ContentTypeImageURL {
-					imageUrl := mediaMessage.GetImageMedia()
-					// check if not base64
-					if strings.HasPrefix(imageUrl.Url, "http") {
-						fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
-						if err != nil {
-							return nil, err
-						}
-						imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
-					}
-					mediaMessage.ImageUrl = imageUrl
-					mediaMessages[j] = mediaMessage
-				}
-			}
-			message.SetMediaContent(mediaMessages)
-		}
-		messages = append(messages, dto.Message{
-			Role:       message.Role,
-			Content:    message.Content,
-			ToolCalls:  message.ToolCalls,
-			ToolCallId: message.ToolCallId,
-		})
-	}
-	str, ok := request.Stop.(string)
-	var Stop []string
-	if ok {
-		Stop = []string{str}
-	} else {
-		Stop, _ = request.Stop.([]string)
-	}
-	ollamaRequest := &OllamaRequest{
-		Model:            request.Model,
-		Messages:         messages,
-		Stream:           request.Stream,
-		Temperature:      request.Temperature,
-		Seed:             request.Seed,
-		Topp:             request.TopP,
-		TopK:             request.TopK,
-		Stop:             Stop,
-		Tools:            request.Tools,
-		MaxTokens:        request.GetMaxTokens(),
-		ResponseFormat:   request.ResponseFormat,
-		FrequencyPenalty: request.FrequencyPenalty,
-		PresencePenalty:  request.PresencePenalty,
-		Prompt:           request.Prompt,
-		StreamOptions:    request.StreamOptions,
-		Suffix:           request.Suffix,
-	}
-	ollamaRequest.Think = request.Think
-	return ollamaRequest, nil
-}
+// openAIChatToOllamaChat converts an OpenAI-style chat request to an Ollama chat request
+func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
+	chatReq := &OllamaChatRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	// format mapping
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" {
+			chatReq.Format = "json"
+		} else if r.ResponseFormat.Type == "json_schema" {
+			// supply the schema object directly
+			if len(r.ResponseFormat.JsonSchema) > 0 {
+				var schema any
+				_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+				chatReq.Format = schema
+			}
+		}
+	}
+
+	// options mapping
+	if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
+	if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
+	if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
+	if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
+
+	// Stop -> options.stop (array)
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string:
+			chatReq.Options["stop"] = []string{v}
+		case []string:
+			chatReq.Options["stop"] = v
+		case []any:
+			arr := make([]string, 0, len(v))
+			for _, i := range v {
+				if s, ok := i.(string); ok {
+					arr = append(arr, s)
+				}
+			}
+			if len(arr) > 0 {
+				chatReq.Options["stop"] = arr
+			}
+		}
+	}
+
+	// tools
+	if len(r.Tools) > 0 {
+		tools := make([]OllamaTool, 0, len(r.Tools))
+		for _, t := range r.Tools {
+			tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
+		}
+		chatReq.Tools = tools
+	}
+
+	// messages
+	chatReq.Messages = make([]OllamaChatMessage, 0, len(r.Messages))
+	for _, m := range r.Messages {
+		// gather text parts & images
+		var textBuilder strings.Builder
+		var images []string
+		if m.IsStringContent() {
+			textBuilder.WriteString(m.StringContent())
+		} else {
+			parts := m.ParseContent()
+			for _, part := range parts {
+				if part.Type == dto.ContentTypeImageURL {
+					img := part.GetImageMedia()
+					if img != nil && img.Url != "" {
+						// ensure a base64 data URL
+						if strings.HasPrefix(img.Url, "http") {
+							fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
+							if err != nil {
+								return nil, err
+							}
+							img.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+						}
+						images = append(images, img.Url)
+					}
+				} else if part.Type == dto.ContentTypeText {
+					textBuilder.WriteString(part.Text)
+				}
+			}
+		}
+		cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
+		if len(images) > 0 { cm.Images = images }
+		// tool result messages carry the tool name
+		if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
+		// tool calls from a previous assistant message
+		if len(m.ToolCalls) > 0 {
+			calls := make([]OllamaToolCall, 0, len(m.ToolCalls))
+			for _, tc := range m.ToolCalls {
+				var args interface{}
+				if tc.Function.Arguments != "" {
+					_ = json.Unmarshal([]byte(tc.Function.Arguments), &args)
+				}
+				oc := OllamaToolCall{}
+				oc.Function.Name = tc.Function.Name
+				if args == nil {
+					args = map[string]any{}
+				}
+				oc.Function.Arguments = args
+				calls = append(calls, oc)
+			}
+			cm.ToolCalls = calls
+		}
+		chatReq.Messages = append(chatReq.Messages, cm)
+	}
+	return chatReq, nil
+}
 
-func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
-	return &OllamaEmbeddingRequest{
-		Model: request.Model,
-		Input: request.ParseInput(),
-		Options: &Options{
-			Seed:             int(request.Seed),
-			Temperature:      request.Temperature,
-			TopP:             request.TopP,
-			FrequencyPenalty: request.FrequencyPenalty,
-			PresencePenalty:  request.PresencePenalty,
-		},
-	}
-}
+// openAIToGenerate converts an OpenAI completions request to an Ollama generate request
+func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
+	gen := &OllamaGenerateRequest{
+		Model:   r.Model,
+		Stream:  r.Stream,
+		Options: map[string]any{},
+		Think:   r.Think,
+	}
+	// Prompt may be in r.Prompt (string or []any)
+	if r.Prompt != nil {
+		switch v := r.Prompt.(type) {
+		case string:
+			gen.Prompt = v
+		case []any:
+			var sb strings.Builder
+			for _, it := range v {
+				if s, ok := it.(string); ok {
+					sb.WriteString(s)
+				}
+			}
+			gen.Prompt = sb.String()
+		default:
+			gen.Prompt = fmt.Sprintf("%v", r.Prompt)
+		}
+	}
+	if r.Suffix != nil {
+		if s, ok := r.Suffix.(string); ok {
+			gen.Suffix = s
+		}
+	}
+	if r.ResponseFormat != nil {
+		if r.ResponseFormat.Type == "json" {
+			gen.Format = "json"
+		} else if r.ResponseFormat.Type == "json_schema" {
+			var schema any
+			_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
+			gen.Format = schema
+		}
+	}
+	if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
+	if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
+	if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
+	if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
+	if r.Stop != nil {
+		switch v := r.Stop.(type) {
+		case string:
+			gen.Options["stop"] = []string{v}
+		case []string:
+			gen.Options["stop"] = v
+		case []any:
+			arr := make([]string, 0, len(v))
+			for _, i := range v {
+				if s, ok := i.(string); ok {
+					arr = append(arr, s)
+				}
+			}
+			if len(arr) > 0 {
+				gen.Options["stop"] = arr
+			}
+		}
+	}
+	return gen, nil
+}
+
+func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+	opts := map[string]any{}
+	if r.Temperature != nil { opts["temperature"] = r.Temperature }
+	if r.TopP != 0 { opts["top_p"] = r.TopP }
+	if r.TopK != 0 { opts["top_k"] = r.TopK }
+	if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
+	if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
+	if r.Seed != 0 { opts["seed"] = int(r.Seed) }
+	if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
+	input := r.ParseInput()
+	if len(input) == 1 {
+		return &OllamaEmbeddingRequest{Model: r.Model, Input: input[0], Options: opts, Dimensions: r.Dimensions}
+	}
+	return &OllamaEmbeddingRequest{Model: r.Model, Input: input, Options: opts, Dimensions: r.Dimensions}
+}
[]string: gen.Options["stop"] = v + case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr } + } + } + return gen, nil +} + +func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest { + opts := map[string]any{} + if r.Temperature != nil { opts["temperature"] = r.Temperature } + if r.TopP != 0 { opts["top_p"] = r.TopP } + if r.TopK != 0 { opts["top_k"] = r.TopK } + if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { opts["seed"] = int(r.Seed) } + if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions } + input := r.ParseInput() + if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} } + return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions} } func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - var ollamaEmbeddingResponse OllamaEmbeddingResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } + var oResp OllamaEmbeddingResponse + body, err := io.ReadAll(resp.Body) + if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } service.CloseResponseBodyGracefully(resp) - err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - if ollamaEmbeddingResponse.Error != "" { - return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) - data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) - data = append(data, dto.OpenAIEmbeddingResponseItem{ - Embedding: flattenedEmbeddings, - Object: "embedding", - }) - usage := &dto.Usage{ - TotalTokens: info.PromptTokens, - CompletionTokens: 0, - PromptTokens: info.PromptTokens, - } - embeddingResponse := &dto.OpenAIEmbeddingResponse{ - Object: "list", - Data: data, - Model: info.UpstreamModelName, - Usage: *usage, - } - doResponseBody, err := common.Marshal(embeddingResponse) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - service.IOCopyBytesGracefully(c, resp, doResponseBody) + if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings)) + for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) } + usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount} + embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage} + out, _ := common.Marshal(embResp) + 
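For reference, the embedding handler above reshapes Ollama's /api/embed reply (one embedding row per input) into the OpenAI list format, one indexed item per row, instead of flattening all rows into a single vector as the removed code did. A self-contained decoding sketch (trimmed response struct; the JSON values are invented):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// embedResp is a trimmed copy of OllamaEmbeddingResponse from the diff.
type embedResp struct {
	Model           string      `json:"model"`
	Embeddings      [][]float64 `json:"embeddings"`
	PromptEvalCount int         `json:"prompt_eval_count"`
}

func main() {
	raw := `{"model":"all-minilm","embeddings":[[0.01,-0.02],[0.03,0.04]],"prompt_eval_count":7}`
	var r embedResp
	if err := json.Unmarshal([]byte(raw), &r); err != nil {
		panic(err)
	}
	// One OpenAI embedding item per row, indexed in order — as the handler does.
	for i, emb := range r.Embeddings {
		fmt.Printf("index=%d dims=%d\n", i, len(emb))
	}
	fmt.Println("prompt tokens:", r.PromptEvalCount) // drives usage.PromptTokens
}
```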
diff --git a/relay/channel/ollama/stream.go b/relay/channel/ollama/stream.go
new file mode 100644
index 000000000..3ae9c6d04
--- /dev/null
+++ b/relay/channel/ollama/stream.go
@@ -0,0 +1,165 @@
+package ollama
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/logger"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"one-api/types"
+	"strings"
+	"time"
+
+	"github.com/gin-gonic/gin"
+)
+
+// ollamaChatStreamChunk is one Ollama streaming chunk (chat or generate)
+type ollamaChatStreamChunk struct {
+	Model     string `json:"model"`
+	CreatedAt string `json:"created_at"`
+	// chat
+	Message *struct {
+		Role      string `json:"role"`
+		Content   string `json:"content"`
+		ToolCalls []struct {
+			Function struct {
+				Name      string      `json:"name"`
+				Arguments interface{} `json:"arguments"`
+			} `json:"function"`
+		} `json:"tool_calls"`
+	} `json:"message"`
+	// generate
+	Response        string `json:"response"`
+	Done            bool   `json:"done"`
+	DoneReason      string `json:"done_reason"`
+	TotalDuration   int64  `json:"total_duration"`
+	LoadDuration    int64  `json:"load_duration"`
+	PromptEvalCount int    `json:"prompt_eval_count"`
+	EvalCount       int    `json:"eval_count"`
+	// generate mode may also report these
+	PromptEvalDuration int64 `json:"prompt_eval_duration"`
+	EvalDuration       int64 `json:"eval_duration"`
+}
+
+// toUnix parses the RFC3339/RFC3339Nano timestamps Ollama emits, falling back to time.Now
+func toUnix(ts string) int64 {
+	if ts == "" {
+		return time.Now().Unix()
+	}
+	t, err := time.Parse(time.RFC3339Nano, ts)
+	if err != nil {
+		if t2, err2 := time.Parse(time.RFC3339, ts); err2 == nil {
+			return t2.Unix()
+		}
+		return time.Now().Unix()
+	}
+	return t.Unix()
+}
+
+// ollamaStreamHandler converts an Ollama NDJSON stream into OpenAI SSE chunks
+func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+	if resp == nil || resp.Body == nil {
+		return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest)
+	}
+	defer service.CloseResponseBodyGracefully(resp)
+
+	helper.SetEventStreamHeaders(c)
+	scanner := bufio.NewScanner(resp.Body)
+	usage := &dto.Usage{}
+	var model = info.UpstreamModelName
+	var responseId = common.GetUUID()
+	var created = time.Now().Unix()
+	var aggregatedText strings.Builder
+	var toolCallIndex int
+	// send start event
+	start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
+	if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
+
+	for scanner.Scan() {
+		line := strings.TrimSpace(scanner.Text())
+		if line == "" { continue }
+		var chunk ollamaChatStreamChunk
+		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+			logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
+			return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+		if chunk.Model != "" { model = chunk.Model }
+		created = toUnix(chunk.CreatedAt)
+
+		if !chunk.Done {
+			// delta content
+			var content string
+			if chunk.Message != nil {
+				content = chunk.Message.Content
+			} else {
+				content = chunk.Response
+			}
+			if content != "" { aggregatedText.WriteString(content) }
+			delta := dto.ChatCompletionsStreamResponse{
+				Id:      responseId,
+				Object:  "chat.completion.chunk",
+				Created: created,
+				Model:   model,
+				Choices: []dto.ChatCompletionsStreamResponseChoice{{
+					Index: 0,
+					Delta: dto.ChatCompletionsStreamResponseChoiceDelta{Role: "assistant"},
+				}},
+			}
+			if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+			// tool calls
+			if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
+				delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 0, len(chunk.Message.ToolCalls))
+				for _, tc := range chunk.Message.ToolCalls {
+					// arguments -> string
+					argBytes, _ := json.Marshal(tc.Function.Arguments)
+					tr := dto.ToolCallResponse{ID: "", Type: nil, Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+					tr.SetIndex(toolCallIndex)
+					toolCallIndex++
+					delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
+				}
+			}
+			if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
+			continue
+		}
+		// done frame
+		usage.PromptTokens = chunk.PromptEvalCount
+		usage.CompletionTokens = chunk.EvalCount
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+		finishReason := chunk.DoneReason
+		if finishReason == "" { finishReason = "stop" }
+		stop := helper.GenerateStopResponse(responseId, created, model, finishReason)
+		if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
+		final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage)
+		if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
+	}
+	if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
+	return usage, nil
+}
+
+// ollamaChatHandler handles non-streaming chat/generate responses
+func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
+	}
+	service.CloseResponseBodyGracefully(resp)
+	if common.DebugEnabled { println("ollama non-stream resp:", string(body)) }
+	var chunk ollamaChatStreamChunk
+	if err = json.Unmarshal(body, &chunk); err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+	model := chunk.Model
+	if model == "" { model = info.UpstreamModelName }
+	created := toUnix(chunk.CreatedAt)
+	content := ""
+	if chunk.Message != nil {
+		content = chunk.Message.Content
+	} else {
+		content = chunk.Response
+	}
+	usage := &dto.Usage{PromptTokens: chunk.PromptEvalCount, CompletionTokens: chunk.EvalCount, TotalTokens: chunk.PromptEvalCount + chunk.EvalCount}
+	finishReason := chunk.DoneReason
+	if finishReason == "" { finishReason = "stop" }
+	// build an OpenAI-style response
+	full := dto.OpenAITextResponse{
+		Id:      common.GetUUID(),
+		Model:   model,
+		Object:  "chat.completion",
+		Created: created,
+		Choices: []dto.OpenAITextResponseChoice{{
+			Index:        0,
+			Message:      dto.Message{Role: "assistant", Content: contentPtr(content)},
+			FinishReason: finishReason,
+		}},
+		Usage: *usage,
+	}
+	out, _ := common.Marshal(full)
+	service.IOCopyBytesGracefully(c, resp, out)
+	return usage, nil
+}
+
+func contentPtr(s string) *string {
+	if s == "" {
+		return nil
+	}
+	return &s
+}
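To make the stream translation concrete: /api/chat and /api/generate stream newline-delimited JSON, one object per line, with a final object carrying done=true plus token counts. A minimal standalone sketch of the per-line decode the handler performs (trimmed chunk struct with chat fields only; the two NDJSON lines are invented examples of that shape):

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// chunk is a trimmed copy of ollamaChatStreamChunk (chat fields only).
type chunk struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	Message   *struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
	Done       bool   `json:"done"`
	DoneReason string `json:"done_reason"`
	EvalCount  int    `json:"eval_count"`
}

func main() {
	lines := []string{
		`{"model":"llama3.2","created_at":"2024-07-22T20:33:28.123Z","message":{"role":"assistant","content":"Hel"},"done":false}`,
		`{"model":"llama3.2","created_at":"2024-07-22T20:33:29.5Z","done":true,"done_reason":"stop","eval_count":12}`,
	}
	for _, l := range lines {
		var ch chunk
		if err := json.Unmarshal([]byte(l), &ch); err != nil {
			panic(err)
		}
		t, _ := time.Parse(time.RFC3339Nano, ch.CreatedAt) // same parse toUnix uses
		if ch.Done {
			fmt.Printf("done @%d reason=%s completion_tokens=%d\n", t.Unix(), ch.DoneReason, ch.EvalCount)
		} else {
			fmt.Printf("delta @%d %q\n", t.Unix(), ch.Message.Content)
		}
	}
}
```

Each non-done chunk becomes one `chat.completion.chunk` SSE delta; the done chunk supplies finish_reason and the usage counts for the final frame.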