Files
new-api/relay/channel/ollama/stream.go

301 lines
9.1 KiB
Go

package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
type ollamaChatStreamChunk struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
// chat
Message *struct {
Role string `json:"role"`
Content string `json:"content"`
Thinking json.RawMessage `json:"thinking"`
ToolCalls []struct {
Function struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
// generate
Response string `json:"response"`
Done bool `json:"done"`
DoneReason string `json:"done_reason"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
EvalCount int `json:"eval_count"`
PromptEvalDuration int64 `json:"prompt_eval_duration"`
EvalDuration int64 `json:"eval_duration"`
}
func toUnix(ts string) int64 {
if ts == "" {
return time.Now().Unix()
}
// try time.RFC3339 or with nanoseconds
t, err := time.Parse(time.RFC3339Nano, ts)
if err != nil {
t2, err2 := time.Parse(time.RFC3339, ts)
if err2 == nil {
return t2.Unix()
}
return time.Now().Unix()
}
return t.Unix()
}
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest)
}
defer service.CloseResponseBodyGracefully(resp)
helper.SetEventStreamHeaders(c)
scanner := bufio.NewScanner(resp.Body)
usage := &dto.Usage{}
var model = info.UpstreamModelName
var responseId = common.GetUUID()
var created = time.Now().Unix()
var toolCallIndex int
start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
if data, err := common.Marshal(start); err == nil {
_ = helper.StringData(c, string(data))
}
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
if line == "" {
continue
}
var chunk ollamaChatStreamChunk
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if chunk.Model != "" {
model = chunk.Model
}
created = toUnix(chunk.CreatedAt)
if !chunk.Done {
// delta content
var content string
if chunk.Message != nil {
content = chunk.Message.Content
} else {
content = chunk.Response
}
delta := dto.ChatCompletionsStreamResponse{
Id: responseId,
Object: "chat.completion.chunk",
Created: created,
Model: model,
Choices: []dto.ChatCompletionsStreamResponseChoice{{
Index: 0,
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{Role: "assistant"},
}},
}
if content != "" {
delta.Choices[0].Delta.SetContentString(content)
}
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(chunk.Message.Thinking))
if raw != "" && raw != "null" {
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(chunk.Message.Thinking, &thinkingContent); err == nil {
delta.Choices[0].Delta.SetReasoningContent(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
delta.Choices[0].Delta.SetReasoningContent(raw)
}
}
}
// tool calls
if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 0, len(chunk.Message.ToolCalls))
for _, tc := range chunk.Message.ToolCalls {
// arguments -> string
argBytes, _ := json.Marshal(tc.Function.Arguments)
toolId := fmt.Sprintf("call_%d", toolCallIndex)
tr := dto.ToolCallResponse{ID: toolId, Type: "function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
tr.SetIndex(toolCallIndex)
toolCallIndex++
delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
}
}
if data, err := common.Marshal(delta); err == nil {
_ = helper.StringData(c, string(data))
}
continue
}
// done frame
// finalize once and break loop
usage.PromptTokens = chunk.PromptEvalCount
usage.CompletionTokens = chunk.EvalCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
finishReason := chunk.DoneReason
if finishReason == "" {
finishReason = "stop"
}
// emit stop delta
if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
if data, err := common.Marshal(stop); err == nil {
_ = helper.StringData(c, string(data))
}
}
// emit usage frame
if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
if data, err := common.Marshal(final); err == nil {
_ = helper.StringData(c, string(data))
}
}
// send [DONE]
helper.Done(c)
break
}
if err := scanner.Err(); err != nil && err != io.EOF {
logger.LogError(c, "ollama stream scan error: "+err.Error())
}
return usage, nil
}
// non-stream handler for chat/generate
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
service.CloseResponseBodyGracefully(resp)
raw := string(body)
if common.DebugEnabled {
println("ollama non-stream raw resp:", raw)
}
lines := strings.Split(raw, "\n")
var (
aggContent strings.Builder
reasoningBuilder strings.Builder
lastChunk ollamaChatStreamChunk
parsedAny bool
)
for _, ln := range lines {
ln = strings.TrimSpace(ln)
if ln == "" {
continue
}
var ck ollamaChatStreamChunk
if err := json.Unmarshal([]byte(ln), &ck); err != nil {
if len(lines) == 1 {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
continue
}
parsedAny = true
lastChunk = ck
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(ck.Message.Thinking))
if raw != "" && raw != "null" {
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(ck.Message.Thinking, &thinkingContent); err == nil {
reasoningBuilder.WriteString(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
reasoningBuilder.WriteString(raw)
}
}
}
if ck.Message != nil && ck.Message.Content != "" {
aggContent.WriteString(ck.Message.Content)
} else if ck.Response != "" {
aggContent.WriteString(ck.Response)
}
}
if !parsedAny {
var single ollamaChatStreamChunk
if err := json.Unmarshal(body, &single); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
lastChunk = single
if single.Message != nil {
if len(single.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(single.Message.Thinking))
if raw != "" && raw != "null" {
// Unmarshal the JSON string to get the actual content without quotes
var thinkingContent string
if err := json.Unmarshal(single.Message.Thinking, &thinkingContent); err == nil {
reasoningBuilder.WriteString(thinkingContent)
} else {
// Fallback to raw string if it's not a JSON string
reasoningBuilder.WriteString(raw)
}
}
}
aggContent.WriteString(single.Message.Content)
} else {
aggContent.WriteString(single.Response)
}
}
model := lastChunk.Model
if model == "" {
model = info.UpstreamModelName
}
created := toUnix(lastChunk.CreatedAt)
usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
content := aggContent.String()
finishReason := lastChunk.DoneReason
if finishReason == "" {
finishReason = "stop"
}
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
if rc := reasoningBuilder.String(); rc != "" {
msg.ReasoningContent = rc
}
full := dto.OpenAITextResponse{
Id: common.GetUUID(),
Model: model,
Object: "chat.completion",
Created: created,
Choices: []dto.OpenAITextResponseChoice{{
Index: 0,
Message: msg,
FinishReason: finishReason,
}},
Usage: *usage,
}
out, _ := common.Marshal(full)
service.IOCopyBytesGracefully(c, resp, out)
return usage, nil
}
func contentPtr(s string) *string {
if s == "" {
return nil
}
return &s
}