diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index cd30a63f..6900e7cd 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -1,23 +1,53 @@ package handler import ( - "bytes" - "crypto/rand" - "encoding/hex" - "encoding/json" - "io" + "context" + "errors" "net/http" - "strings" "time" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" ) -// ChatCompletions handles OpenAI Chat Completions API compatibility. +// ChatCompletions handles OpenAI Chat Completions API requests. // POST /v1/chat/completions func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { - body, err := io.ReadAll(c.Request.Body) + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -31,516 +61,230 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { return } - // Preserve original chat-completions request for upstream passthrough when needed. - c.Set(service.OpenAIChatCompletionsBodyKey, body) - - var chatReq map[string]any - if err := json.Unmarshal(body, &chatReq); err != nil { + if !gjson.ValidBytes(body) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - includeUsage := false - if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok { - if v, ok := streamOptions["include_usage"].(bool); ok { - includeUsage = v - } + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return } - c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage) + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() - converted, err := service.ConvertChatCompletionsToResponses(chatReq) - if err != nil { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) return } - convertedBody, err := json.Marshal(converted) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } + sessionHash := h.gatewayService.GenerateSessionHash(c, body) + promptCacheKey := h.gatewayService.ExtractSessionID(c, body) - stream, _ := converted["stream"].(bool) - model, _ := converted["model"].(string) - originalWriter := c.Writer - writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model) - c.Writer = writer - c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody)) - c.Request.ContentLength = int64(len(convertedBody)) + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError - h.Responses(c) - writer.Finalize() - c.Writer = originalWriter -} - -type chatCompletionsResponseWriter struct { - gin.ResponseWriter - stream bool - includeUsage bool - buffer bytes.Buffer - streamBuf bytes.Buffer - state *chatCompletionStreamState - corrector *service.CodexToolCorrector - finalized bool - passthrough bool -} - -type chatCompletionStreamState struct { - id string - model string - created int64 - sentRole bool - sawToolCall bool - sawText bool - toolCallIndex map[string]int - usage map[string]any -} - -func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter { - return &chatCompletionsResponseWriter{ - ResponseWriter: w, - stream: stream, - includeUsage: includeUsage, - state: &chatCompletionStreamState{ - model: strings.TrimSpace(model), - toolCallIndex: make(map[string]int), - }, - corrector: service.NewCodexToolCorrector(), - } -} - -func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) { - if w.passthrough { - return w.ResponseWriter.Write(data) - } - if w.stream { - n, err := w.streamBuf.Write(data) - if err != nil { - return n, err - } - w.flushStreamBuffer() - return n, nil - } - - if w.finalized { - return len(data), nil - } - return w.buffer.Write(data) -} - -func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) { - return w.Write([]byte(s)) -} - -func (w *chatCompletionsResponseWriter) Finalize() { - if w.finalized { - return - } - w.finalized = true - if w.passthrough { - return - } - if w.stream { - return - } - - body := w.buffer.Bytes() - if len(body) == 0 { - return - } - - w.ResponseWriter.Header().Del("Content-Length") - - converted, err := service.ConvertResponsesToChatCompletion(body) - if err != nil { - _, _ = w.ResponseWriter.Write(body) - return - } - - corrected := converted - if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok { - corrected = []byte(correctedStr) - } - - _, _ = w.ResponseWriter.Write(corrected) -} - -func (w *chatCompletionsResponseWriter) SetPassthrough() { - w.passthrough = true -} - -func (w *chatCompletionsResponseWriter) Status() int { - if w.ResponseWriter == nil { - return 0 - } - return w.ResponseWriter.Status() -} - -func (w *chatCompletionsResponseWriter) Written() bool { - if w.ResponseWriter == nil { - return false - } - return w.ResponseWriter.Written() -} - -func (w *chatCompletionsResponseWriter) flushStreamBuffer() { for { - buf := w.streamBuf.Bytes() - idx := bytes.IndexByte(buf, '\n') - if idx == -1 { - return - } - lineBytes := w.streamBuf.Next(idx + 1) - line := strings.TrimRight(string(lineBytes), "\r\n") - w.handleStreamLine(line) - } -} - -func (w *chatCompletionsResponseWriter) handleStreamLine(line string) { - if line == "" { - return - } - if strings.HasPrefix(line, ":") { - _, _ = w.ResponseWriter.Write([]byte(line + "\n\n")) - return - } - if !strings.HasPrefix(line, "data:") { - return - } - - data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - for _, chunk := range w.convertResponseDataToChatChunks(data) { - if chunk == "" { - continue - } - if chunk == "[DONE]" { - _, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) - continue - } - _, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n")) - } -} - -func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string { - if data == "" { - return nil - } - if data == "[DONE]" { - return []string{"[DONE]"} - } - - var payload map[string]any - if err := json.Unmarshal([]byte(data), &payload); err != nil { - return []string{data} - } - - if _, ok := payload["error"]; ok { - return []string{data} - } - - eventType := strings.TrimSpace(getString(payload["type"])) - if eventType == "" { - return []string{data} - } - - w.state.applyMetadata(payload) - - switch eventType { - case "response.created": - return nil - case "response.output_text.delta": - delta := getString(payload["delta"]) - if delta == "" { - return nil - } - w.state.sawText = true - return []string{w.buildTextDeltaChunk(delta)} - case "response.output_text.done": - if w.state.sawText { - return nil - } - text := getString(payload["text"]) - if text == "" { - return nil - } - w.state.sawText = true - return []string{w.buildTextDeltaChunk(text)} - case "response.output_item.added", "response.output_item.delta": - if item, ok := payload["item"].(map[string]any); ok { - if callID, name, args, ok := extractToolCallFromItem(item); ok { - w.state.sawToolCall = true - return []string{w.buildToolCallChunk(callID, name, args)} + c.Set("openai_chat_completions_fallback_model", "") + reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err != nil { + reqLog.Warn("openai_chat_completions.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + reqLog.Info("openai_chat_completions.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_chat_completions_fallback_model", defaultModel) + } + } + if err != nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return + } + } else { + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + } + return } } - case "response.completed", "response.done": - if responseObj, ok := payload["response"].(map[string]any); ok { - w.state.applyResponseUsage(responseObj) + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return } - return []string{w.buildFinalChunk()} - } + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + _ = scheduleDecision + setOpsSelectedAccount(c, account.ID, account.Platform) - if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") { - callID := strings.TrimSpace(getString(payload["call_id"])) - if callID == "" { - callID = strings.TrimSpace(getString(payload["tool_call_id"])) + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return } - if callID == "" { - callID = strings.TrimSpace(getString(payload["id"])) + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() + + defaultMappedModel := "" + if apiKey.Group != nil { + defaultMappedModel = apiKey.Group.DefaultMappedModel } - args := getString(payload["delta"]) - name := strings.TrimSpace(getString(payload["name"])) - if callID != "" && (args != "" || name != "") { - w.state.sawToolCall = true - return []string{w.buildToolCallChunk(callID, name, args)} + if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" { + defaultMappedModel = fallbackModel } - } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) - return nil -} - -func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string { - w.state.ensureDefaults() - payload := map[string]any{ - "content": delta, - } - if !w.state.sentRole { - payload["role"] = "assistant" - w.state.sentRole = true - } - return w.buildChunk(payload, nil, nil) -} - -func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string { - w.state.ensureDefaults() - index := w.state.toolCallIndexFor(callID) - function := map[string]any{} - if name != "" { - function["name"] = name - } - if args != "" { - function["arguments"] = args - } - toolCall := map[string]any{ - "index": index, - "id": callID, - "type": "function", - "function": function, - } - - delta := map[string]any{ - "tool_calls": []any{toolCall}, - } - if !w.state.sentRole { - delta["role"] = "assistant" - w.state.sentRole = true - } - - return w.buildChunk(delta, nil, nil) -} - -func (w *chatCompletionsResponseWriter) buildFinalChunk() string { - w.state.ensureDefaults() - finishReason := "stop" - if w.state.sawToolCall { - finishReason = "tool_calls" - } - usage := map[string]any(nil) - if w.includeUsage && w.state.usage != nil { - usage = w.state.usage - } - return w.buildChunk(map[string]any{}, finishReason, usage) -} - -func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string { - w.state.ensureDefaults() - chunk := map[string]any{ - "id": w.state.id, - "object": "chat.completion.chunk", - "created": w.state.created, - "model": w.state.model, - "choices": []any{ - map[string]any{ - "index": 0, - "delta": delta, - "finish_reason": finishReason, - }, - }, - } - if usage != nil { - chunk["usage"] = usage - } - - data, _ := json.Marshal(chunk) - if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok { - return corrected - } - return string(data) -} - -func (s *chatCompletionStreamState) ensureDefaults() { - if s.id == "" { - s.id = "chatcmpl-" + randomHexUnsafe(12) - } - if s.model == "" { - s.model = "unknown" - } - if s.created == 0 { - s.created = time.Now().Unix() - } -} - -func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int { - if idx, ok := s.toolCallIndex[callID]; ok { - return idx - } - idx := len(s.toolCallIndex) - s.toolCallIndex[callID] = idx - return idx -} - -func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) { - if responseObj, ok := payload["response"].(map[string]any); ok { - s.applyResponseMetadata(responseObj) - } - - if s.id == "" { - if id := strings.TrimSpace(getString(payload["response_id"])); id != "" { - s.id = id - } else if id := strings.TrimSpace(getString(payload["id"])); id != "" { - s.id = id + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + accountReleaseFunc() } - } - if s.model == "" { - if model := strings.TrimSpace(getString(payload["model"])); model != "" { - s.model = model + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs } - } - if s.created == 0 { - if created := getInt64(payload["created_at"]); created != 0 { - s.created = created - } else if created := getInt64(payload["created"]); created != 0 { - s.created = created + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // Pool mode: retry on the same account + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_chat_completions.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Warn("openai_chat_completions.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return + } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) } - } -} -func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) { - if s.id == "" { - if id := strings.TrimSpace(getString(responseObj["id"])); id != "" { - s.id = id - } - } - if s.model == "" { - if model := strings.TrimSpace(getString(responseObj["model"])); model != "" { - s.model = model - } - } - if s.created == 0 { - if created := getInt64(responseObj["created_at"]); created != 0 { - s.created = created - } - } -} + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) -func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) { - usage, ok := responseObj["usage"].(map[string]any) - if !ok { + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_chat_completions.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_chat_completions.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) return } - promptTokens := int(getNumber(usage["input_tokens"])) - completionTokens := int(getNumber(usage["output_tokens"])) - if promptTokens == 0 && completionTokens == 0 { - return - } - s.usage = map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": completionTokens, - "total_tokens": promptTokens + completionTokens, - } -} - -func extractToolCallFromItem(item map[string]any) (string, string, string, bool) { - itemType := strings.TrimSpace(getString(item["type"])) - if itemType != "tool_call" && itemType != "function_call" { - return "", "", "", false - } - callID := strings.TrimSpace(getString(item["call_id"])) - if callID == "" { - callID = strings.TrimSpace(getString(item["id"])) - } - name := strings.TrimSpace(getString(item["name"])) - args := getString(item["arguments"]) - if fn, ok := item["function"].(map[string]any); ok { - if name == "" { - name = strings.TrimSpace(getString(fn["name"])) - } - if args == "" { - args = getString(fn["arguments"]) - } - } - if callID == "" && name == "" && args == "" { - return "", "", "", false - } - if callID == "" { - callID = "call_" + randomHexUnsafe(6) - } - return callID, name, args, true -} - -func getString(value any) string { - switch v := value.(type) { - case string: - return v - case []byte: - return string(v) - case json.Number: - return v.String() - default: - return "" - } -} - -func getNumber(value any) float64 { - switch v := value.(type) { - case float64: - return v - case float32: - return float64(v) - case int: - return float64(v) - case int64: - return float64(v) - case json.Number: - f, _ := v.Float64() - return f - default: - return 0 - } -} - -func getInt64(value any) int64 { - switch v := value.(type) { - case int64: - return v - case int: - return int64(v) - case float64: - return int64(v) - case json.Number: - i, _ := v.Int64() - return i - default: - return 0 - } -} - -func randomHexUnsafe(byteLength int) string { - if byteLength <= 0 { - byteLength = 8 - } - buf := make([]byte, byteLength) - if _, err := rand.Read(buf); err != nil { - return "000000" - } - return hex.EncodeToString(buf) } diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go new file mode 100644 index 00000000..71b7a6f5 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -0,0 +1,733 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ChatCompletionsToResponses tests +// --------------------------------------------------------------------------- + +func TestChatCompletionsToResponses_BasicText(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "gpt-4o", resp.Model) + assert.True(t, resp.Stream) // always forced true + assert.False(t, *resp.Store) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) +} + +func TestChatCompletionsToResponses_SystemMessage(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "system", Content: json.RawMessage(`"You are helpful."`)}, + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + assert.Equal(t, "user", items[1].Role) +} + +func TestChatCompletionsToResponses_ToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Call the function"`)}, + { + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_1", + Type: "function", + Function: ChatFunctionCall{ + Name: "ping", + Arguments: `{"host":"example.com"}`, + }, + }, + }, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: json.RawMessage(`"pong"`), + }, + }, + Tools: []ChatTool{ + { + Type: "function", + Function: &ChatFunction{ + Name: "ping", + Description: "Ping a host", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output = 3 + // (assistant message with empty content + tool_calls → only function_call items emitted) + require.Len(t, items, 3) + + // Check function_call item + assert.Equal(t, "function_call", items[1].Type) + assert.Equal(t, "call_1", items[1].CallID) + assert.Equal(t, "ping", items[1].Name) + + // Check function_call_output item + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "call_1", items[2].CallID) + assert.Equal(t, "pong", items[2].Output) + + // Check tools + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "ping", resp.Tools[0].Name) +} + +func TestChatCompletionsToResponses_MaxTokens(t *testing.T) { + t.Run("max_tokens", func(t *testing.T) { + maxTokens := 100 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: &maxTokens, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + // Below minMaxOutputTokens (128), should be clamped + assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens) + }) + + t.Run("max_completion_tokens_preferred", func(t *testing.T) { + maxTokens := 100 + maxCompletion := 500 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: &maxTokens, + MaxCompletionTokens: &maxCompletion, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + assert.Equal(t, 500, *resp.MaxOutputTokens) + }) +} + +func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ReasoningEffort: "high", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestChatCompletionsToResponses_ImageURL(t *testing.T) { + content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]` + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(content)}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Describe this", parts[0].Text) + assert.Equal(t, "input_image", parts[1].Type) + assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) +} + +func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + Functions: []ChatFunction{ + { + Name: "get_weather", + Description: "Get weather", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + FunctionCall: json.RawMessage(`{"name":"get_weather"}`), + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "get_weather", resp.Tools[0].Name) + + // tool_choice should be converted + require.NotNil(t, resp.ToolChoice) + var tc map[string]any + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "function", tc["type"]) +} + +func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ServiceTier: "flex", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "flex", resp.ServiceTier) +} + +func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Do something"`)}, + { + Role: "assistant", + Content: json.RawMessage(`"Let me call a function."`), + ToolCalls: []ChatToolCall{ + { + ID: "call_abc", + Type: "function", + Function: ChatFunctionCall{ + Name: "do_thing", + Arguments: `{}`, + }, + }, + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant message (with text) + function_call + require.Len(t, items, 3) + assert.Equal(t, "user", items[0].Role) + assert.Equal(t, "assistant", items[1].Role) + assert.Equal(t, "function_call", items[2].Type) +} + +// --------------------------------------------------------------------------- +// ResponsesToChatCompletions tests +// --------------------------------------------------------------------------- + +func TestResponsesToChatCompletions_BasicText(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_123", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Hello, world!"}, + }, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + assert.Equal(t, "chat.completion", chat.Object) + assert.Equal(t, "gpt-4o", chat.Model) + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "Hello, world!", content) + + require.NotNil(t, chat.Usage) + assert.Equal(t, 10, chat.Usage.PromptTokens) + assert.Equal(t, 5, chat.Usage.CompletionTokens) + assert.Equal(t, 15, chat.Usage.TotalTokens) +} + +func TestResponsesToChatCompletions_ToolCalls(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_456", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_xyz", + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason) + + msg := chat.Choices[0].Message + require.Len(t, msg.ToolCalls, 1) + assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID) + assert.Equal(t, "function", msg.ToolCalls[0].Type) + assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name) + assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments) +} + +func TestResponsesToChatCompletions_Reasoning(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_789", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "reasoning", + Summary: []ResponsesSummary{ + {Type: "summary_text", Text: "I thought about it."}, + }, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "The answer is 42."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + // Reasoning summary is prepended to text + assert.Equal(t, "I thought about it.The answer is 42.", content) +} + +func TestResponsesToChatCompletions_Incomplete(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_inc", + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "partial..."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "length", chat.Choices[0].FinishReason) +} + +func TestResponsesToChatCompletions_CachedTokens(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cache", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}}, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 100, + OutputTokens: 10, + TotalTokens: 110, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 80, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.NotNil(t, chat.Usage) + require.NotNil(t, chat.Usage.PromptTokensDetails) + assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesToChatCompletions_WebSearch(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_ws", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "web_search_call", + Action: &WebSearchAction{Type: "search", Query: "test"}, + }, + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}}, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "search results", content) +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesEventToChatChunks tests +// --------------------------------------------------------------------------- + +func TestResponsesEventToChatChunks_TextDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + + // response.created → role chunk + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ + ID: "resp_stream", + }, + }, state) + require.Len(t, chunks, 1) + assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role) + assert.True(t, state.SentRole) + + // response.output_text.delta → content chunk + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Hello", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content) +} + +func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + // response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0) + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 1, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_1", + Name: "get_weather", + }, + }, state) + require.Len(t, chunks, 1) + require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1) + tc := chunks[0].Choices[0].Delta.ToolCalls[0] + assert.Equal(t, "call_1", tc.ID) + assert.Equal(t, "get_weather", tc.Function.Name) + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index) + + // response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, // matches the output_index from output_item.added above + Delta: `{"city":`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call") + assert.Equal(t, `{"city":`, tc.Function.Arguments) + + // Add a second function call at output_index=2 + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 2, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_2", + Name: "get_time", + }, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool call should get index 1") + + // Argument delta for second tool call + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 2, + Delta: `{"tz":"UTC"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1") + + // Argument delta for first tool call (interleaved) + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, + Delta: `"Tokyo"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0") +} + +func TestResponsesEventToChatChunks_Completed(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 50, + OutputTokens: 20, + TotalTokens: 70, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 30, + }, + }, + }, + }, state) + // finish chunk + usage chunk + require.Len(t, chunks, 2) + + // First chunk: finish_reason + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Second chunk: usage + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 50, chunks[1].Usage.PromptTokens) + assert.Equal(t, 20, chunks[1].Usage.CompletionTokens) + assert.Equal(t, 70, chunks[1].Usage.TotalTokens) + require.NotNil(t, chunks[1].Usage.PromptTokensDetails) + assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SawToolCall = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + }, + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason) +} + +func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + Delta: "Thinking...", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content) +} + +func TestFinalizeResponsesChatStream(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + state.Usage = &ChatUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + chunks := FinalizeResponsesChatStream(state) + require.Len(t, chunks, 2) + + // Finish chunk + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 100, chunks[1].Usage.PromptTokens) + + // Idempotent: second call returns nil + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) { + // If response.completed already emitted the finish chunk, FinalizeResponsesChatStream + // must be a no-op (prevents double finish_reason being sent to the client). + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + // Simulate response.completed + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }, + }, state) + require.NotEmpty(t, chunks) // finish + usage chunks + + // Now FinalizeResponsesChatStream should return nil — already finalized. + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestChatChunkToSSE(t *testing.T) { + chunk := ChatCompletionsChunk{ + ID: "chatcmpl-test", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "gpt-4o", + Choices: []ChatChunkChoice{ + { + Index: 0, + Delta: ChatDelta{Role: "assistant"}, + FinishReason: nil, + }, + }, + } + + sse, err := ChatChunkToSSE(chunk) + require.NoError(t, err) + assert.Contains(t, sse, "data: ") + assert.Contains(t, sse, "chatcmpl-test") + assert.Contains(t, sse, "assistant") + assert.True(t, len(sse) > 10) +} + +// --------------------------------------------------------------------------- +// Stream round-trip test +// --------------------------------------------------------------------------- + +func TestChatCompletionsStreamRoundTrip(t *testing.T) { + // Simulate: client sends chat completions request, upstream returns Responses SSE events. + // Verify that the streaming state machine produces correct chat completions chunks. + + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + var allChunks []ChatCompletionsChunk + + // 1. response.created + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_rt"}, + }, state) + allChunks = append(allChunks, chunks...) + + // 2. text deltas + for _, text := range []string{"Hello", ", ", "world", "!"} { + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: text, + }, state) + allChunks = append(allChunks, chunks...) + } + + // 3. response.completed + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 4, + TotalTokens: 14, + }, + }, + }, state) + allChunks = append(allChunks, chunks...) + + // Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7 + require.Len(t, allChunks, 7) + + // First chunk has role + assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role) + + // Text chunks + var fullText string + for i := 1; i <= 4; i++ { + require.NotNil(t, allChunks[i].Choices[0].Delta.Content) + fullText += *allChunks[i].Choices[0].Delta.Content + } + assert.Equal(t, "Hello, world!", fullText) + + // Finish chunk + require.NotNil(t, allChunks[5].Choices[0].FinishReason) + assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, allChunks[6].Usage) + assert.Equal(t, 10, allChunks[6].Usage.PromptTokens) + assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens) + + // All chunks share the same ID + for _, c := range allChunks { + assert.Equal(t, "resp_rt", c.ID) + } +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go new file mode 100644 index 00000000..37285b09 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -0,0 +1,312 @@ +package apicompat + +import ( + "encoding/json" + "fmt" +) + +// ChatCompletionsToResponses converts a Chat Completions request into a +// Responses API request. The upstream always streams, so Stream is forced to +// true. store is always false and reasoning.encrypted_content is always +// included so that the response translator has full context. +func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) { + input, err := convertChatMessagesToResponsesInput(req.Messages) + if err != nil { + return nil, err + } + + inputJSON, err := json.Marshal(input) + if err != nil { + return nil, err + } + + out := &ResponsesRequest{ + Model: req.Model, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: true, // upstream always streams + Include: []string{"reasoning.encrypted_content"}, + ServiceTier: req.ServiceTier, + } + + storeFalse := false + out.Store = &storeFalse + + // max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens + maxTokens := 0 + if req.MaxTokens != nil { + maxTokens = *req.MaxTokens + } + if req.MaxCompletionTokens != nil { + maxTokens = *req.MaxCompletionTokens + } + if maxTokens > 0 { + v := maxTokens + if v < minMaxOutputTokens { + v = minMaxOutputTokens + } + out.MaxOutputTokens = &v + } + + // reasoning_effort → reasoning.effort + reasoning.summary="auto" + if req.ReasoningEffort != "" { + out.Reasoning = &ResponsesReasoning{ + Effort: req.ReasoningEffort, + Summary: "auto", + } + } + + // tools[] and legacy functions[] → ResponsesTool[] + if len(req.Tools) > 0 || len(req.Functions) > 0 { + out.Tools = convertChatToolsToResponses(req.Tools, req.Functions) + } + + // tool_choice: already compatible format — pass through directly. + // Legacy function_call needs mapping. + if len(req.ToolChoice) > 0 { + out.ToolChoice = req.ToolChoice + } else if len(req.FunctionCall) > 0 { + tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall) + if err != nil { + return nil, fmt.Errorf("convert function_call: %w", err) + } + out.ToolChoice = tc + } + + return out, nil +} + +// convertChatMessagesToResponsesInput converts the Chat Completions messages +// array into a Responses API input items array. +func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) { + var out []ResponsesInputItem + for _, m := range msgs { + items, err := chatMessageToResponsesItems(m) + if err != nil { + return nil, err + } + out = append(out, items...) + } + return out, nil +} + +// chatMessageToResponsesItems converts a single ChatMessage into one or more +// ResponsesInputItem values. +func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) { + switch m.Role { + case "system": + return chatSystemToResponses(m) + case "user": + return chatUserToResponses(m) + case "assistant": + return chatAssistantToResponses(m) + case "tool": + return chatToolToResponses(m) + case "function": + return chatFunctionToResponses(m) + default: + return chatUserToResponses(m) + } +} + +// chatSystemToResponses converts a system message. +func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + text, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + content, err := json.Marshal(text) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "system", Content: content}}, nil +} + +// chatUserToResponses converts a user message, handling both plain strings and +// multi-modal content arrays. +func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + // Try plain string first. + var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + content, _ := json.Marshal(s) + return []ResponsesInputItem{{Role: "user", Content: content}}, nil + } + + var parts []ChatContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil, fmt.Errorf("parse user content: %w", err) + } + + var responseParts []ResponsesContentPart + for _, p := range parts { + switch p.Type { + case "text": + if p.Text != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_text", + Text: p.Text, + }) + } + case "image_url": + if p.ImageURL != nil && p.ImageURL.URL != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_image", + ImageURL: p.ImageURL.URL, + }) + } + } + } + + content, err := json.Marshal(responseParts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "user", Content: content}}, nil +} + +// chatAssistantToResponses converts an assistant message. If there is both +// text content and tool_calls, the text is emitted as an assistant message +// first, then each tool_call becomes a function_call item. If the content is +// empty/nil and there are tool_calls, only function_call items are emitted. +func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + var items []ResponsesInputItem + + // Emit assistant message with output_text if content is non-empty. + if len(m.Content) > 0 { + var s string + if err := json.Unmarshal(m.Content, &s); err == nil && s != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + } + } + + // Emit one function_call item per tool_call. + for _, tc := range m.ToolCalls { + args := tc.Function.Arguments + if args == "" { + args = "{}" + } + items = append(items, ResponsesInputItem{ + Type: "function_call", + CallID: tc.ID, + Name: tc.Function.Name, + Arguments: args, + ID: tc.ID, + }) + } + + return items, nil +} + +// chatToolToResponses converts a tool result message (role=tool) into a +// function_call_output item. +func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.ToolCallID, + Output: output, + }}, nil +} + +// chatFunctionToResponses converts a legacy function result message +// (role=function) into a function_call_output item. The Name field is used as +// call_id since legacy function calls do not carry a separate call_id. +func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.Name, + Output: output, + }}, nil +} + +// parseChatContent returns the string value of a ChatMessage Content field. +// Content must be a JSON string. Returns "" if content is null or empty. +func parseChatContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", nil + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "", fmt.Errorf("parse content as string: %w", err) + } + return s, nil +} + +// convertChatToolsToResponses maps Chat Completions tool definitions and legacy +// function definitions to Responses API tool definitions. +func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool { + var out []ResponsesTool + + for _, t := range tools { + if t.Type != "function" || t.Function == nil { + continue + } + rt := ResponsesTool{ + Type: "function", + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + Strict: t.Function.Strict, + } + out = append(out, rt) + } + + // Legacy functions[] are treated as function-type tools. + for _, f := range functions { + rt := ResponsesTool{ + Type: "function", + Name: f.Name, + Description: f.Description, + Parameters: f.Parameters, + Strict: f.Strict, + } + out = append(out, rt) + } + + return out +} + +// convertChatFunctionCallToToolChoice maps the legacy function_call field to a +// Responses API tool_choice value. +// +// "auto" → "auto" +// "none" → "none" +// {"name":"X"} → {"type":"function","function":{"name":"X"}} +func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try string first ("auto", "none", etc.) — pass through as-is. + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Object form: {"name":"X"} + var obj struct { + Name string `json:"name"` + } + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, err + } + return json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{"name": obj.Name}, + }) +} diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go new file mode 100644 index 00000000..8f83bce4 --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -0,0 +1,368 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ResponsesResponse → ChatCompletionsResponse +// --------------------------------------------------------------------------- + +// ResponsesToChatCompletions converts a Responses API response into a Chat +// Completions response. Text output items are concatenated into +// choices[0].message.content; function_call items become tool_calls. +func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse { + id := resp.ID + if id == "" { + id = generateChatCmplID() + } + + out := &ChatCompletionsResponse{ + ID: id, + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + } + + var contentText string + var toolCalls []ChatToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, part := range item.Content { + if part.Type == "output_text" && part.Text != "" { + contentText += part.Text + } + } + case "function_call": + toolCalls = append(toolCalls, ChatToolCall{ + ID: item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: item.Name, + Arguments: item.Arguments, + }, + }) + case "reasoning": + for _, s := range item.Summary { + if s.Type == "summary_text" && s.Text != "" { + contentText += s.Text + } + } + case "web_search_call": + // silently consumed — results already incorporated into text output + } + } + + msg := ChatMessage{Role: "assistant"} + if len(toolCalls) > 0 { + msg.ToolCalls = toolCalls + } + if contentText != "" { + raw, _ := json.Marshal(contentText) + msg.Content = raw + } + + finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls) + + out.Choices = []ChatChoice{{ + Index: 0, + Message: msg, + FinishReason: finishReason, + }} + + if resp.Usage != nil { + usage := &ChatUsage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: resp.Usage.InputTokensDetails.CachedTokens, + } + } + out.Usage = usage + } + + return out +} + +func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string { + switch status { + case "incomplete": + if details != nil && details.Reason == "max_output_tokens" { + return "length" + } + return "stop" + case "completed": + if len(toolCalls) > 0 { + return "tool_calls" + } + return "stop" + default: + return "stop" + } +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter) +// --------------------------------------------------------------------------- + +// ResponsesEventToChatState tracks state for converting a sequence of Responses +// SSE events into Chat Completions SSE chunks. +type ResponsesEventToChatState struct { + ID string + Model string + Created int64 + SentRole bool + SawToolCall bool + SawText bool + Finalized bool // true after finish chunk has been emitted + NextToolCallIndex int // next sequential tool_call index to assign + OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index + IncludeUsage bool + Usage *ChatUsage +} + +// NewResponsesEventToChatState returns an initialised stream state. +func NewResponsesEventToChatState() *ResponsesEventToChatState { + return &ResponsesEventToChatState{ + ID: generateChatCmplID(), + Created: time.Now().Unix(), + OutputIndexToToolIndex: make(map[int]int), + } +} + +// ResponsesEventToChatChunks converts a single Responses SSE event into zero +// or more Chat Completions chunks, updating state as it goes. +func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + switch evt.Type { + case "response.created": + return resToChatHandleCreated(evt, state) + case "response.output_text.delta": + return resToChatHandleTextDelta(evt, state) + case "response.output_item.added": + return resToChatHandleOutputItemAdded(evt, state) + case "response.function_call_arguments.delta": + return resToChatHandleFuncArgsDelta(evt, state) + case "response.reasoning_summary_text.delta": + return resToChatHandleReasoningDelta(evt, state) + case "response.completed", "response.incomplete", "response.failed": + return resToChatHandleCompleted(evt, state) + default: + return nil + } +} + +// FinalizeResponsesChatStream emits a final chunk with finish_reason if the +// stream ended without a proper completion event (e.g. upstream disconnect). +// It is idempotent: if a completion event already emitted the finish chunk, +// this returns nil. +func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk { + if state.Finalized { + return nil + } + state.Finalized = true + + finishReason := "stop" + if state.SawToolCall { + finishReason = "tool_calls" + } + + chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)} + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line. +func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) { + data, err := json.Marshal(chunk) + if err != nil { + return "", err + } + return fmt.Sprintf("data: %s\n\n", data), nil +} + +// --- internal handlers --- + +func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Response != nil { + if evt.Response.ID != "" { + state.ID = evt.Response.ID + } + if state.Model == "" && evt.Response.Model != "" { + state.Model = evt.Response.Model + } + } + // Emit the role chunk. + if state.SentRole { + return nil + } + state.SentRole = true + + role := "assistant" + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})} +} + +func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + state.SawText = true + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Item == nil || evt.Item.Type != "function_call" { + return nil + } + + state.SawToolCall = true + idx := state.NextToolCallIndex + state.OutputIndexToToolIndex[evt.OutputIndex] = idx + state.NextToolCallIndex++ + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + ID: evt.Item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: evt.Item.Name, + }, + }}, + })} +} + +func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + + idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex] + if !ok { + return nil + } + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + Function: ChatFunctionCall{ + Arguments: evt.Delta, + }, + }}, + })} +} + +func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + state.Finalized = true + finishReason := "stop" + + if evt.Response != nil { + if evt.Response.Usage != nil { + u := evt.Response.Usage + usage := &ChatUsage{ + PromptTokens: u.InputTokens, + CompletionTokens: u.OutputTokens, + TotalTokens: u.InputTokens + u.OutputTokens, + } + if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: u.InputTokensDetails.CachedTokens, + } + } + state.Usage = usage + } + + switch evt.Response.Status { + case "incomplete": + if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" { + finishReason = "length" + } + case "completed": + if state.SawToolCall { + finishReason = "tool_calls" + } + } + } else if state.SawToolCall { + finishReason = "tool_calls" + } + + var chunks []ChatCompletionsChunk + chunks = append(chunks, makeChatFinishChunk(state, finishReason)) + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk { + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: delta, + FinishReason: nil, + }}, + } +} + +func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk { + empty := "" + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatDelta{Content: &empty}, + FinishReason: &finishReason, + }}, + } +} + +// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID. +func generateChatCmplID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "chatcmpl-" + hex.EncodeToString(b) +} diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index aa58b58f..eb77d89f 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -329,6 +329,148 @@ type ResponsesStreamEvent struct { SequenceNumber int `json:"sequence_number,omitempty"` } +// --------------------------------------------------------------------------- +// OpenAI Chat Completions API types +// --------------------------------------------------------------------------- + +// ChatCompletionsRequest is the request body for POST /v1/chat/completions. +type ChatCompletionsRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` + Tools []ChatTool `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" + ServiceTier string `json:"service_tier,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` // string or []string + + // Legacy function calling (deprecated but still supported) + Functions []ChatFunction `json:"functions,omitempty"` + FunctionCall json.RawMessage `json:"function_call,omitempty"` +} + +// ChatStreamOptions configures streaming behavior. +type ChatStreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +// ChatMessage is a single message in the Chat Completions conversation. +type ChatMessage struct { + Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + + // Legacy function calling + FunctionCall *ChatFunctionCall `json:"function_call,omitempty"` +} + +// ChatContentPart is a typed content part in a multi-modal message. +type ChatContentPart struct { + Type string `json:"type"` // "text" | "image_url" + Text string `json:"text,omitempty"` + ImageURL *ChatImageURL `json:"image_url,omitempty"` +} + +// ChatImageURL contains the URL for an image content part. +type ChatImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` // "auto" | "low" | "high" +} + +// ChatTool describes a tool available to the model. +type ChatTool struct { + Type string `json:"type"` // "function" + Function *ChatFunction `json:"function,omitempty"` +} + +// ChatFunction describes a function tool definition. +type ChatFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ChatToolCall represents a tool call made by the assistant. +// Index is only populated in streaming chunks (omitted in non-streaming responses). +type ChatToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` // "function" + Function ChatFunctionCall `json:"function"` +} + +// ChatFunctionCall contains the function name and arguments. +type ChatFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions. +type ChatCompletionsResponse struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChoice is a single completion choice. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter" +} + +// ChatUsage holds token counts in Chat Completions format. +type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"` +} + +// ChatTokenDetails provides a breakdown of token usage. +type ChatTokenDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions. +type ChatCompletionsChunk struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion.chunk" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChunkChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChunkChoice is a single choice in a streaming chunk. +type ChatChunkChoice struct { + Index int `json:"index"` + Delta ChatDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` // pointer: null when not final +} + +// ChatDelta carries incremental content in a streaming chunk. +type ChatDelta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` +} + // --------------------------------------------------------------------------- // Shared constants // --------------------------------------------------------------------------- diff --git a/backend/internal/service/openai_chat_completions.go b/backend/internal/service/openai_chat_completions.go deleted file mode 100644 index c4c95ff2..00000000 --- a/backend/internal/service/openai_chat_completions.go +++ /dev/null @@ -1,513 +0,0 @@ -package service - -import ( - "encoding/json" - "errors" - "strings" - "time" -) - -// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request. -func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) { - if req == nil { - return nil, errors.New("request is nil") - } - - model := strings.TrimSpace(getString(req["model"])) - if model == "" { - return nil, errors.New("model is required") - } - - messagesRaw, ok := req["messages"] - if !ok { - return nil, errors.New("messages is required") - } - messages, ok := messagesRaw.([]any) - if !ok { - return nil, errors.New("messages must be an array") - } - - input, err := convertChatMessagesToResponsesInput(messages) - if err != nil { - return nil, err - } - - out := make(map[string]any, len(req)+1) - for key, value := range req { - switch key { - case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call": - continue - default: - out[key] = value - } - } - - out["model"] = model - out["input"] = input - - if _, ok := out["max_output_tokens"]; !ok { - if v, ok := req["max_tokens"]; ok { - out["max_output_tokens"] = v - } else if v, ok := req["max_completion_tokens"]; ok { - out["max_output_tokens"] = v - } - } - - if _, ok := out["tools"]; !ok { - if functions, ok := req["functions"].([]any); ok && len(functions) > 0 { - tools := make([]any, 0, len(functions)) - for _, fn := range functions { - if fnMap, ok := fn.(map[string]any); ok { - tools = append(tools, map[string]any{ - "type": "function", - "function": fnMap, - }) - } - } - if len(tools) > 0 { - out["tools"] = tools - } - } - } - - if _, ok := out["tool_choice"]; !ok { - if functionCall, ok := req["function_call"]; ok { - out["tool_choice"] = functionCall - } - } - - return out, nil -} - -// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format. -func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return nil, err - } - - id := strings.TrimSpace(getString(resp["id"])) - if id == "" { - id = "chatcmpl-" + safeRandomHex(12) - } - model := strings.TrimSpace(getString(resp["model"])) - - created := getInt64(resp["created_at"]) - if created == 0 { - created = getInt64(resp["created"]) - } - if created == 0 { - created = time.Now().Unix() - } - - text, toolCalls := extractResponseTextAndToolCalls(resp) - finishReason := "stop" - if len(toolCalls) > 0 { - finishReason = "tool_calls" - } - - message := map[string]any{ - "role": "assistant", - "content": text, - } - if len(toolCalls) > 0 { - message["tool_calls"] = toolCalls - } - - chatResp := map[string]any{ - "id": id, - "object": "chat.completion", - "created": created, - "model": model, - "choices": []any{ - map[string]any{ - "index": 0, - "message": message, - "finish_reason": finishReason, - }, - }, - } - - if usage := extractResponseUsage(resp); usage != nil { - chatResp["usage"] = usage - } - if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" { - chatResp["system_fingerprint"] = fingerprint - } - - return json.Marshal(chatResp) -} - -func convertChatMessagesToResponsesInput(messages []any) ([]any, error) { - input := make([]any, 0, len(messages)) - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) - if !ok { - return nil, errors.New("message must be an object") - } - role := strings.TrimSpace(getString(msgMap["role"])) - if role == "" { - return nil, errors.New("message role is required") - } - - switch role { - case "tool": - callID := strings.TrimSpace(getString(msgMap["tool_call_id"])) - if callID == "" { - callID = strings.TrimSpace(getString(msgMap["id"])) - } - output := extractMessageContentText(msgMap["content"]) - input = append(input, map[string]any{ - "type": "function_call_output", - "call_id": callID, - "output": output, - }) - case "function": - callID := strings.TrimSpace(getString(msgMap["name"])) - output := extractMessageContentText(msgMap["content"]) - input = append(input, map[string]any{ - "type": "function_call_output", - "call_id": callID, - "output": output, - }) - default: - convertedContent := convertChatContent(msgMap["content"]) - toolCalls := []any(nil) - if role == "assistant" { - toolCalls = extractToolCallsFromMessage(msgMap) - } - skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent) - if !skipAssistantMessage { - msgItem := map[string]any{ - "role": role, - "content": convertedContent, - } - if name := strings.TrimSpace(getString(msgMap["name"])); name != "" { - msgItem["name"] = name - } - input = append(input, msgItem) - } - if role == "assistant" && len(toolCalls) > 0 { - input = append(input, toolCalls...) - } - } - } - return input, nil -} - -func convertChatContent(content any) any { - switch v := content.(type) { - case nil: - return "" - case string: - return v - case []any: - converted := make([]any, 0, len(v)) - for _, part := range v { - partMap, ok := part.(map[string]any) - if !ok { - converted = append(converted, part) - continue - } - partType := strings.TrimSpace(getString(partMap["type"])) - switch partType { - case "text": - text := getString(partMap["text"]) - if text != "" { - converted = append(converted, map[string]any{ - "type": "input_text", - "text": text, - }) - continue - } - case "image_url": - imageURL := "" - if imageObj, ok := partMap["image_url"].(map[string]any); ok { - imageURL = getString(imageObj["url"]) - } else { - imageURL = getString(partMap["image_url"]) - } - if imageURL != "" { - converted = append(converted, map[string]any{ - "type": "input_image", - "image_url": imageURL, - }) - continue - } - case "input_text", "input_image": - converted = append(converted, partMap) - continue - } - converted = append(converted, partMap) - } - return converted - default: - return v - } -} - -func extractToolCallsFromMessage(msg map[string]any) []any { - var out []any - if toolCalls, ok := msg["tool_calls"].([]any); ok { - for _, call := range toolCalls { - callMap, ok := call.(map[string]any) - if !ok { - continue - } - callID := strings.TrimSpace(getString(callMap["id"])) - if callID == "" { - callID = strings.TrimSpace(getString(callMap["call_id"])) - } - name := "" - args := "" - if fn, ok := callMap["function"].(map[string]any); ok { - name = strings.TrimSpace(getString(fn["name"])) - args = getString(fn["arguments"]) - } - if name == "" && args == "" { - continue - } - item := map[string]any{ - "type": "tool_call", - } - if callID != "" { - item["call_id"] = callID - } - if name != "" { - item["name"] = name - } - if args != "" { - item["arguments"] = args - } - out = append(out, item) - } - } - - if fnCall, ok := msg["function_call"].(map[string]any); ok { - name := strings.TrimSpace(getString(fnCall["name"])) - args := getString(fnCall["arguments"]) - if name != "" || args != "" { - callID := strings.TrimSpace(getString(msg["tool_call_id"])) - if callID == "" { - callID = name - } - item := map[string]any{ - "type": "function_call", - } - if callID != "" { - item["call_id"] = callID - } - if name != "" { - item["name"] = name - } - if args != "" { - item["arguments"] = args - } - out = append(out, item) - } - } - - return out -} - -func extractMessageContentText(content any) string { - switch v := content.(type) { - case string: - return v - case []any: - parts := make([]string, 0, len(v)) - for _, part := range v { - partMap, ok := part.(map[string]any) - if !ok { - continue - } - partType := strings.TrimSpace(getString(partMap["type"])) - if partType == "" || partType == "text" || partType == "output_text" || partType == "input_text" { - text := getString(partMap["text"]) - if text != "" { - parts = append(parts, text) - } - } - } - return strings.Join(parts, "") - default: - return "" - } -} - -func isEmptyContent(content any) bool { - switch v := content.(type) { - case nil: - return true - case string: - return strings.TrimSpace(v) == "" - case []any: - return len(v) == 0 - default: - return false - } -} - -func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) { - output, ok := resp["output"].([]any) - if !ok { - if text, ok := resp["output_text"].(string); ok { - return text, nil - } - return "", nil - } - - textParts := make([]string, 0) - toolCalls := make([]any, 0) - - for _, item := range output { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType := strings.TrimSpace(getString(itemMap["type"])) - - if itemType == "tool_call" || itemType == "function_call" { - if tc := responseItemToChatToolCall(itemMap); tc != nil { - toolCalls = append(toolCalls, tc) - } - continue - } - - content := itemMap["content"] - switch v := content.(type) { - case string: - if v != "" { - textParts = append(textParts, v) - } - case []any: - for _, part := range v { - partMap, ok := part.(map[string]any) - if !ok { - continue - } - partType := strings.TrimSpace(getString(partMap["type"])) - switch partType { - case "output_text", "text", "input_text": - text := getString(partMap["text"]) - if text != "" { - textParts = append(textParts, text) - } - case "tool_call", "function_call": - if tc := responseItemToChatToolCall(partMap); tc != nil { - toolCalls = append(toolCalls, tc) - } - } - } - } - } - - return strings.Join(textParts, ""), toolCalls -} - -func responseItemToChatToolCall(item map[string]any) map[string]any { - callID := strings.TrimSpace(getString(item["call_id"])) - if callID == "" { - callID = strings.TrimSpace(getString(item["id"])) - } - name := strings.TrimSpace(getString(item["name"])) - arguments := getString(item["arguments"]) - if fn, ok := item["function"].(map[string]any); ok { - if name == "" { - name = strings.TrimSpace(getString(fn["name"])) - } - if arguments == "" { - arguments = getString(fn["arguments"]) - } - } - - if name == "" && arguments == "" && callID == "" { - return nil - } - - if callID == "" { - callID = "call_" + safeRandomHex(6) - } - - return map[string]any{ - "id": callID, - "type": "function", - "function": map[string]any{ - "name": name, - "arguments": arguments, - }, - } -} - -func extractResponseUsage(resp map[string]any) map[string]any { - usage, ok := resp["usage"].(map[string]any) - if !ok { - return nil - } - promptTokens := int(getNumber(usage["input_tokens"])) - completionTokens := int(getNumber(usage["output_tokens"])) - if promptTokens == 0 && completionTokens == 0 { - return nil - } - - return map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": completionTokens, - "total_tokens": promptTokens + completionTokens, - } -} - -func getString(value any) string { - switch v := value.(type) { - case string: - return v - case []byte: - return string(v) - case json.Number: - return v.String() - default: - return "" - } -} - -func getNumber(value any) float64 { - switch v := value.(type) { - case float64: - return v - case float32: - return float64(v) - case int: - return float64(v) - case int64: - return float64(v) - case json.Number: - f, _ := v.Float64() - return f - default: - return 0 - } -} - -func getInt64(value any) int64 { - switch v := value.(type) { - case int64: - return v - case int: - return int64(v) - case float64: - return int64(v) - case json.Number: - i, _ := v.Int64() - return i - default: - return 0 - } -} - -func safeRandomHex(byteLength int) string { - value, err := randomHexString(byteLength) - if err != nil || value == "" { - return "000000" - } - return value -} diff --git a/backend/internal/service/openai_chat_completions_forward.go b/backend/internal/service/openai_chat_completions_forward.go deleted file mode 100644 index 0eefdb35..00000000 --- a/backend/internal/service/openai_chat_completions_forward.go +++ /dev/null @@ -1,488 +0,0 @@ -package service - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "strings" - "sync/atomic" - "time" - - "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" - "github.com/gin-gonic/gin" -) - -type chatStreamingResult struct { - usage *OpenAIUsage - firstTokenMs *int -} - -func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) { - // Parse request body once (avoid multiple parse/serialize cycles) - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - return nil, fmt.Errorf("parse request: %w", err) - } - - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - originalModel := reqModel - - bodyModified := false - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) - reqBody["model"] = mappedModel - bodyModified = true - } - - if reqStream && includeUsage { - streamOptions, _ := reqBody["stream_options"].(map[string]any) - if streamOptions == nil { - streamOptions = map[string]any{} - } - if _, ok := streamOptions["include_usage"]; !ok { - streamOptions["include_usage"] = true - reqBody["stream_options"] = streamOptions - bodyModified = true - } - } - - if bodyModified { - var err error - body, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("serialize request body: %w", err) - } - } - - // Get access token - token, _, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - - upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token) - if err != nil { - return nil, err - } - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - return nil, fmt.Errorf("upstream request failed: %s", safeErr) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= 400 { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} - } - return s.handleErrorResponse(ctx, resp, c, account, body) - } - - var usage *OpenAIUsage - var firstTokenMs *int - if reqStream { - streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) - if err != nil { - return nil, err - } - usage = streamResult.usage - firstTokenMs = streamResult.firstTokenMs - } else { - usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel) - if err != nil { - return nil, err - } - } - - if usage == nil { - usage = &OpenAIUsage{} - } - - return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: originalModel, - Stream: reqStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - }, nil -} - -func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) { - var targetURL string - baseURL := account.GetOpenAIBaseURL() - if baseURL == "" { - targetURL = openaiChatAPIURL - } else { - validatedURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return nil, err - } - targetURL = validatedURL + "/chat/completions" - } - - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - req.Header.Set("authorization", "Bearer "+token) - - for key, values := range c.Request.Header { - lowerKey := strings.ToLower(key) - if openaiChatAllowedHeaders[lowerKey] { - for _, v := range values { - req.Header.Add(key, v) - } - } - } - - customUA := account.GetOpenAIUserAgent() - if customUA != "" { - req.Header.Set("user-agent", customUA) - } - - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") - } - - return req, nil -} - -func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) { - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - } - - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - - if v := resp.Header.Get("x-request-id"); v != "" { - c.Header("x-request-id", v) - } - - w := c.Writer - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("streaming not supported") - } - - usage := &OpenAIUsage{} - var firstTokenMs *int - - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) - - type scanEvent struct { - line string - err error - } - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - _ = sendEvent(scanEvent{err: err}) - } - }() - defer close(done) - - streamInterval := time.Duration(0) - if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { - streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second - } - var intervalTicker *time.Ticker - if streamInterval > 0 { - intervalTicker = time.NewTicker(streamInterval) - defer intervalTicker.Stop() - } - var intervalCh <-chan time.Time - if intervalTicker != nil { - intervalCh = intervalTicker.C - } - - keepaliveInterval := time.Duration(0) - if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { - keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second - } - var keepaliveTicker *time.Ticker - if keepaliveInterval > 0 { - keepaliveTicker = time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() - } - var keepaliveCh <-chan time.Time - if keepaliveTicker != nil { - keepaliveCh = keepaliveTicker.C - } - lastDataAt := time.Now() - - errorEventSent := false - sendErrorEvent := func(reason string) { - if errorEventSent { - return - } - errorEventSent = true - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) - flusher.Flush() - } - - needModelReplace := originalModel != mappedModel - - for { - select { - case ev, ok := <-events: - if !ok { - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - if ev.err != nil { - if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err - } - sendErrorEvent("stream_read_error") - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) - } - - line := ev.line - lastDataAt = time.Now() - - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") - - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - - if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { - line = "data: " + correctedData - } - - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - - if firstTokenMs == nil { - if event := parseChatStreamEvent(data); event != nil { - if chatChunkHasDelta(event) { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - applyChatUsageFromEvent(event, usage) - } - } else { - if event := parseChatStreamEvent(data); event != nil { - applyChatUsageFromEvent(event, usage) - } - } - } else { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - } - - case <-intervalCh: - lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) - if time.Since(lastRead) < streamInterval { - continue - } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) - if s.rateLimitService != nil { - s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) - } - sendErrorEvent("stream_timeout") - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") - - case <-keepaliveCh: - if time.Since(lastDataAt) < keepaliveInterval { - continue - } - if _, err := fmt.Fprint(w, ":\n\n"); err != nil { - return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - } - } -} - -func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - usage := &OpenAIUsage{} - var parsed map[string]any - if json.Unmarshal(body, &parsed) == nil { - if usageMap, ok := parsed["usage"].(map[string]any); ok { - applyChatUsageFromMap(usageMap, usage) - } - } - - if originalModel != mappedModel { - body = s.replaceModelInResponseBody(body, mappedModel, originalModel) - } - body = s.correctToolCallsInResponseBody(body) - - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - } - - contentType := "application/json" - if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { - if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { - contentType = upstreamType - } - } - - c.Data(resp.StatusCode, contentType, body) - return usage, nil -} - -func parseChatStreamEvent(data string) map[string]any { - if data == "" || data == "[DONE]" { - return nil - } - var event map[string]any - if json.Unmarshal([]byte(data), &event) != nil { - return nil - } - return event -} - -func chatChunkHasDelta(event map[string]any) bool { - choices, ok := event["choices"].([]any) - if !ok { - return false - } - for _, choice := range choices { - choiceMap, ok := choice.(map[string]any) - if !ok { - continue - } - delta, ok := choiceMap["delta"].(map[string]any) - if !ok { - continue - } - if content, ok := delta["content"].(string); ok && strings.TrimSpace(content) != "" { - return true - } - if toolCalls, ok := delta["tool_calls"].([]any); ok && len(toolCalls) > 0 { - return true - } - if functionCall, ok := delta["function_call"].(map[string]any); ok && len(functionCall) > 0 { - return true - } - } - return false -} - -func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) { - if event == nil || usage == nil { - return - } - usageMap, ok := event["usage"].(map[string]any) - if !ok { - return - } - applyChatUsageFromMap(usageMap, usage) -} - -func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) { - if usageMap == nil || usage == nil { - return - } - promptTokens := int(getNumber(usageMap["prompt_tokens"])) - completionTokens := int(getNumber(usageMap["completion_tokens"])) - if promptTokens > 0 { - usage.InputTokens = promptTokens - } - if completionTokens > 0 { - usage.OutputTokens = completionTokens - } -} diff --git a/backend/internal/service/openai_chat_completions_test.go b/backend/internal/service/openai_chat_completions_test.go deleted file mode 100644 index 635bda23..00000000 --- a/backend/internal/service/openai_chat_completions_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package service - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestConvertChatCompletionsToResponses(t *testing.T) { - req := map[string]any{ - "model": "gpt-4o", - "messages": []any{ - map[string]any{ - "role": "user", - "content": "hello", - }, - map[string]any{ - "role": "assistant", - "tool_calls": []any{ - map[string]any{ - "id": "call_1", - "type": "function", - "function": map[string]any{ - "name": "ping", - "arguments": "{}", - }, - }, - }, - }, - map[string]any{ - "role": "tool", - "tool_call_id": "call_1", - "content": "ok", - "response": "ignored", - "response_time": 1, - }, - }, - "functions": []any{ - map[string]any{ - "name": "ping", - "description": "ping tool", - "parameters": map[string]any{"type": "object"}, - }, - }, - "function_call": map[string]any{"name": "ping"}, - } - - converted, err := ConvertChatCompletionsToResponses(req) - require.NoError(t, err) - require.Equal(t, "gpt-4o", converted["model"]) - - input, ok := converted["input"].([]any) - require.True(t, ok) - require.Len(t, input, 3) - - toolCall := findInputItemByType(input, "tool_call") - require.NotNil(t, toolCall) - require.Equal(t, "call_1", toolCall["call_id"]) - - toolOutput := findInputItemByType(input, "function_call_output") - require.NotNil(t, toolOutput) - require.Equal(t, "call_1", toolOutput["call_id"]) - - tools, ok := converted["tools"].([]any) - require.True(t, ok) - require.Len(t, tools, 1) - - require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"]) -} - -func TestConvertResponsesToChatCompletion(t *testing.T) { - resp := map[string]any{ - "id": "resp_123", - "model": "gpt-4o", - "created_at": 1700000000, - "output": []any{ - map[string]any{ - "type": "message", - "role": "assistant", - "content": []any{ - map[string]any{ - "type": "output_text", - "text": "hi", - }, - }, - }, - }, - "usage": map[string]any{ - "input_tokens": 2, - "output_tokens": 3, - }, - } - body, err := json.Marshal(resp) - require.NoError(t, err) - - converted, err := ConvertResponsesToChatCompletion(body) - require.NoError(t, err) - - var chat map[string]any - require.NoError(t, json.Unmarshal(converted, &chat)) - require.Equal(t, "chat.completion", chat["object"]) - - choices, ok := chat["choices"].([]any) - require.True(t, ok) - require.Len(t, choices, 1) - - choice, ok := choices[0].(map[string]any) - require.True(t, ok) - message, ok := choice["message"].(map[string]any) - require.True(t, ok) - require.Equal(t, "hi", message["content"]) - - usage, ok := chat["usage"].(map[string]any) - require.True(t, ok) - require.Equal(t, float64(2), usage["prompt_tokens"]) - require.Equal(t, float64(3), usage["completion_tokens"]) - require.Equal(t, float64(5), usage["total_tokens"]) -} - -func findInputItemByType(items []any, itemType string) map[string]any { - for _, item := range items { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - if itemMap["type"] == itemType { - return itemMap - } - } - return nil -} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go new file mode 100644 index 00000000..f893eeb9 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -0,0 +1,512 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ForwardAsChatCompletions accepts a Chat Completions request body, converts it +// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts +// the response back to Chat Completions format. All account types (OAuth and API +// Key) go through the Responses API conversion path since the upstream only +// exposes the /v1/responses endpoint. +func (s *OpenAIGatewayService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + promptCacheKey string, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse Chat Completions request + var chatReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &chatReq); err != nil { + return nil, fmt.Errorf("parse chat completions request: %w", err) + } + originalModel := chatReq.Model + clientStream := chatReq.Stream + includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage + + // 2. Convert to Responses and forward + // ChatCompletionsToResponses always sets Stream=true (upstream always streams). + responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + + // 3. Model mapping + mappedModel := account.GetMappedModel(originalModel) + if mappedModel == originalModel && defaultMappedModel != "" { + mappedModel = defaultMappedModel + } + responsesReq.Model = mappedModel + + logger.L().Debug("openai chat_completions: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("stream", clientStream), + ) + + // 4. Marshal Responses request body, then apply OAuth codex transform + responsesBody, err := json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } + + if account.Type == AccountTypeOAuth { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for codex transform: %w", err) + } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) + } + } + + // 5. Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 6. Build upstream request + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + if promptCacheKey != "" { + upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) + } + + // 7. Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleChatCompletionsErrorResponse(resp, c, account) + } + + // 9. Handle normal response + var result *OpenAIForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + } else { + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } + + // Propagate ServiceTier and ReasoningEffort to result for billing + if handleErr == nil && result != nil { + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + result.ServiceTier = &st + } + if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { + re := responsesReq.Reasoning.Effort + result.ReasoningEffort = &re + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if handleErr == nil && account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + return result, handleErr +} + +// handleChatCompletionsErrorResponse reads an upstream error and returns it in +// OpenAI Chat Completions error format. +func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, +) (*OpenAIForwardResult, error) { + return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) +} + +// handleChatBufferedStreamingResponse reads all Responses SSE events from the +// upstream, finds the terminal event, converts to a Chat Completions JSON +// response, and writes it to the client. +func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResponse *apicompat.ResponsesResponse + var usage OpenAIUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + payload := line[6:] + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil { + finalResponse = event.Response + if event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResponse == nil { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event") + return nil, fmt.Errorf("upstream stream ended without terminal event") + } + + chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, chatResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleChatStreamingResponse reads Responses SSE events from upstream, +// converts each to Chat Completions SSE chunks, and writes them to the client. +func (s *OpenAIGatewayService) handleChatStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + includeUsage bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewResponsesEventToChatState() + state.Model = originalModel + state.IncludeUsage = includeUsage + + var usage OpenAIUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *OpenAIForwardResult { + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + processDataLine := func(payload string) bool { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + return false + } + + // Extract usage from completion events + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil && event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + + chunks := apicompat.ResponsesEventToChatChunks(&event, state) + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("openai chat_completions stream: client disconnected", + zap.String("request_id", requestID), + ) + return true + } + } + if len(chunks) > 0 { + c.Writer.Flush() + } + return false + } + + finalizeStream := func() (*OpenAIForwardResult, error) { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + for _, chunk := range finalChunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + } + // Send [DONE] sentinel + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + return resultWithUsage(), nil + } + + handleScanErr := func(err error) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Determine keepalive interval + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + + // No keepalive: fast synchronous path + if keepaliveInterval <= 0 { + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + } + handleScanErr(scanner.Err()) + return finalizeStream() + } + + // With keepalive: goroutine + channel + select + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + keepaliveTicker := time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + lastDataAt := time.Now() + + for { + select { + case ev, ok := <-events: + if !ok { + return finalizeStream() + } + if ev.err != nil { + handleScanErr(ev.err) + return finalizeStream() + } + lastDataAt = time.Now() + line := ev.line + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + + case <-keepaliveTicker.C: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // Send SSE comment as keepalive + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil { + logger.L().Info("openai chat_completions stream: client disconnected during keepalive", + zap.String("request_id", requestID), + ) + return resultWithUsage(), nil + } + c.Writer.Flush() + } + } +} + +// writeChatCompletionsError writes an error response in OpenAI Chat Completions format. +func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 46fc68a9..e4a3d9c0 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody), + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } // Non-failover error: return Anthropic-formatted error to client @@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse( c *gin.Context, account *Account, ) (*OpenAIForwardResult, error) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) - if upstreamMsg == "" { - upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) - } - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - // Record upstream error details for ops logging - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(body), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - - // Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go) - if status, errType, errMsg, matched := applyErrorPassthroughRule( - c, account.Platform, resp.StatusCode, body, - http.StatusBadGateway, "api_error", "Upstream request failed", - ); matched { - writeAnthropicError(c, status, errType, errMsg) - if upstreamMsg == "" { - upstreamMsg = errMsg - } - if upstreamMsg == "" { - return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) - } - - errType := "api_error" - switch { - case resp.StatusCode == 400: - errType = "invalid_request_error" - case resp.StatusCode == 404: - errType = "not_found_error" - case resp.StatusCode == 429: - errType = "rate_limit_error" - case resp.StatusCode >= 500: - errType = "api_error" - } - - writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg) - return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) } // handleAnthropicBufferedStreamingResponse reads all Responses SSE events from diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 54068f2b..5edf890e 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -12,7 +12,6 @@ import ( "io" "math/rand" "net/http" - "regexp" "sort" "strconv" "strings" @@ -37,7 +36,6 @@ const ( chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses" // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" - openaiChatAPIURL = "https://api.openai.com/v1/chat/completions" openaiStickySessionTTL = time.Hour // 粘性会话TTL codexCLIUserAgent = "codex_cli_rs/0.104.0" // codex_cli_only 拒绝时单个请求头日志长度上限(字符) @@ -56,16 +54,6 @@ const ( codexCLIVersion = "0.104.0" ) -// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context. -const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body" - -// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context. -const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage" - -// openaiSSEDataRe matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) - // OpenAI allowed headers whitelist (for non-passthrough). var openaiAllowedHeaders = map[string]bool{ "accept-language": true, @@ -109,19 +97,6 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{ "X-Real-IP", } -// OpenAI chat-completions allowed headers (extend responses whitelist). -var openaiChatAllowedHeaders = map[string]bool{ - "accept-language": true, - "content-type": true, - "conversation_id": true, - "user-agent": true, - "originator": true, - "session_id": true, - "openai-organization": true, - "openai-project": true, - "openai-beta": true, -} - // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers type OpenAICodexUsageSnapshot struct { PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` @@ -1602,23 +1577,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") } - if c != nil && account != nil && account.Type == AccountTypeAPIKey { - if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok { - if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 { - includeUsage := false - if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok { - if flag, ok := v.(bool); ok { - includeUsage = flag - } - } - if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok { - passthroughWriter.SetPassthrough() - } - return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime) - } - } - } - originalBody := body reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) originalModel := reqModel @@ -2989,6 +2947,120 @@ func (s *OpenAIGatewayService) handleErrorResponse( return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } +// compatErrorWriter is the signature for format-specific error writers used by +// the compat paths (Chat Completions and Anthropic Messages). +type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string) + +// handleCompatErrorResponse is the shared non-failover error handler for the +// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of +// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit +// tracking, secondary failover) but delegates the final error write to the +// format-specific writer function. +func (s *OpenAIGatewayService) handleCompatErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, + writeError compatErrorWriter, +) (*OpenAIForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if upstreamMsg == "" { + upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) + } + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // Apply error passthrough rules + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, account.Platform, resp.StatusCode, body, + http.StatusBadGateway, "api_error", "Upstream request failed", + ); matched { + writeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + // Check custom error codes — if the account does not handle this status, + // return a generic error without exposing upstream details. + if !account.ShouldHandleErrorCode(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error") + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) + } + + // Track rate limits and decide whether to trigger secondary failover. + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) + } + kind := "http_error" + if shouldDisable { + kind = "failover" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if shouldDisable { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // Map status code to error type and write response + errType := "api_error" + switch { + case resp.StatusCode == 400: + errType = "invalid_request_error" + case resp.StatusCode == 404: + errType = "not_found_error" + case resp.StatusCode == 429: + errType = "rate_limit_error" + case resp.StatusCode >= 500: + errType = "api_error" + } + + writeError(c, resp.StatusCode, errType, upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) +} + // openaiStreamingResult streaming response result type openaiStreamingResult struct { usage *OpenAIUsage