diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go
index 55bc57b9..28ca5b01 100644
--- a/backend/internal/service/openai_gateway_messages.go
+++ b/backend/internal/service/openai_gateway_messages.go
@@ -39,7 +39,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
 		return nil, fmt.Errorf("parse anthropic request: %w", err)
 	}
 	originalModel := anthropicReq.Model
-	isStream := anthropicReq.Stream
+	clientStream := anthropicReq.Stream // client's original stream preference
 
 	// 2. Convert Anthropic → Responses
 	responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
@@ -47,6 +47,11 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
 		return nil, fmt.Errorf("convert anthropic to responses: %w", err)
 	}
 
+	// Upstream always uses streaming (upstream may not support sync mode).
+	// The client's original preference determines the response format.
+	responsesReq.Stream = true
+	isStream := true
+
 	// 2b. Handle BetaFastMode → service_tier: "priority"
 	if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
 		responsesReq.ServiceTier = "priority"
@@ -169,12 +174,14 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
 	}
 
 	// 9. Handle normal response
+	// Upstream is always streaming; choose response format based on client preference.
 	var result *OpenAIForwardResult
 	var handleErr error
-	if isStream {
+	if clientStream {
 		result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
 	} else {
-		result, handleErr = s.handleAnthropicNonStreamingResponse(resp, c, originalModel, mappedModel, startTime)
+		// Client wants JSON: buffer the streaming response and assemble a JSON reply.
+		result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
 	}
 
 	// Propagate ServiceTier and ReasoningEffort to result for billing
@@ -256,9 +263,13 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
 	return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
 }
 
-// handleAnthropicNonStreamingResponse reads a Responses API JSON response,
-// converts it to Anthropic Messages format, and writes it to the client.
-func (s *OpenAIGatewayService) handleAnthropicNonStreamingResponse(
+// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
+// the upstream streaming response, finds the terminal event (response.completed
+// / response.incomplete / response.failed), converts the complete response to
+// Anthropic Messages JSON format, and writes it to the client.
+// This is used when the client requested stream=false but the upstream is always
+// streaming.
+func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
 	resp *http.Response,
 	c *gin.Context,
 	originalModel string,
@@ -267,29 +278,61 @@
 ) (*OpenAIForwardResult, error) {
 	requestID := resp.Header.Get("x-request-id")
 
-	respBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("read upstream response: %w", err)
-	}
-
-	var responsesResp apicompat.ResponsesResponse
-	if err := json.Unmarshal(respBody, &responsesResp); err != nil {
-		return nil, fmt.Errorf("parse responses response: %w", err)
-	}
-
-	anthropicResp := apicompat.ResponsesToAnthropic(&responsesResp, originalModel)
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
 
+	var finalResponse *apicompat.ResponsesResponse
 	var usage OpenAIUsage
-	if responsesResp.Usage != nil {
-		usage = OpenAIUsage{
-			InputTokens:  responsesResp.Usage.InputTokens,
-			OutputTokens: responsesResp.Usage.OutputTokens,
+
+	for scanner.Scan() {
+		line := scanner.Text()
+
+		if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
+			continue
 		}
-		if responsesResp.Usage.InputTokensDetails != nil {
-			usage.CacheReadInputTokens = responsesResp.Usage.InputTokensDetails.CachedTokens
+		payload := line[6:]
+
+		var event apicompat.ResponsesStreamEvent
+		if err := json.Unmarshal([]byte(payload), &event); err != nil {
+			logger.L().Warn("openai messages buffered: failed to parse event",
+				zap.Error(err),
+				zap.String("request_id", requestID),
+			)
+			continue
+		}
+
+		// Terminal events carry the complete ResponsesResponse with output + usage.
+		if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
+			event.Response != nil {
+			finalResponse = event.Response
+			if event.Response.Usage != nil {
+				usage = OpenAIUsage{
+					InputTokens:  event.Response.Usage.InputTokens,
+					OutputTokens: event.Response.Usage.OutputTokens,
+				}
+				if event.Response.Usage.InputTokensDetails != nil {
+					usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
+				}
+			}
 		}
 	}
 
+	if err := scanner.Err(); err != nil {
+		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
+			logger.L().Warn("openai messages buffered: read error",
+				zap.Error(err),
+				zap.String("request_id", requestID),
+			)
+		}
+	}
+
+	if finalResponse == nil {
+		writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
+		return nil, fmt.Errorf("upstream stream ended without terminal event")
+	}
+
+	anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
+
 	if s.responseHeaderFilter != nil {
 		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
 	}
@@ -307,6 +350,9 @@
 
 // handleAnthropicStreamingResponse reads Responses SSE events from upstream,
 // converts each to Anthropic SSE events, and writes them to the client.
+// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
+// pattern to send Anthropic ping events during periods of upstream silence,
+// preventing proxy/client timeout disconnections.
 func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
 	resp *http.Response,
 	c *gin.Context,
@@ -322,6 +368,7 @@
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
 	c.Writer.Header().Set("Cache-Control", "no-cache")
 	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
 	c.Writer.WriteHeader(http.StatusOK)
 
 	state := apicompat.NewResponsesEventToAnthropicState()
@@ -333,28 +380,35 @@
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
 
-	for scanner.Scan() {
-		line := scanner.Text()
-
-		if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
-			continue
+	// resultWithUsage builds the final result snapshot.
+	resultWithUsage := func() *OpenAIForwardResult {
+		return &OpenAIForwardResult{
+			RequestID:    requestID,
+			Usage:        usage,
+			Model:        originalModel,
+			BillingModel: mappedModel,
+			Stream:       true,
+			Duration:     time.Since(startTime),
+			FirstTokenMs: firstTokenMs,
 		}
-		payload := line[6:]
+	}
 
+	// processDataLine handles a single "data: ..." SSE line from upstream.
+	// Returns (clientDisconnected bool).
+	processDataLine := func(payload string) bool {
 		if firstChunk {
 			firstChunk = false
 			ms := int(time.Since(startTime).Milliseconds())
 			firstTokenMs = &ms
 		}
 
-		// Parse the Responses SSE event
 		var event apicompat.ResponsesStreamEvent
 		if err := json.Unmarshal([]byte(payload), &event); err != nil {
 			logger.L().Warn("openai messages stream: failed to parse event",
 				zap.Error(err),
 				zap.String("request_id", requestID),
 			)
-			continue
+			return false
 		}
 
 		// Extract usage from completion events
@@ -381,28 +435,36 @@
 				continue
 			}
 			if _, err := fmt.Fprint(c.Writer, sse); err != nil {
-				// Client disconnected — return collected usage
 				logger.L().Info("openai messages stream: client disconnected",
 					zap.String("request_id", requestID),
 				)
-				return &OpenAIForwardResult{
-					RequestID:    requestID,
-					Usage:        usage,
-					Model:        originalModel,
-					BillingModel: mappedModel,
-					Stream:       true,
-					Duration:     time.Since(startTime),
-					FirstTokenMs: firstTokenMs,
-				}, nil
+				return true
 			}
 		}
 		if len(events) > 0 {
 			c.Writer.Flush()
 		}
+		return false
 	}
 
-	if err := scanner.Err(); err != nil {
-		if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
+	// finalizeStream sends any remaining Anthropic events and returns the result.
+	finalizeStream := func() (*OpenAIForwardResult, error) {
+		if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
+			for _, evt := range finalEvents {
+				sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
+				if err != nil {
+					continue
+				}
+				fmt.Fprint(c.Writer, sse) //nolint:errcheck
+			}
+			c.Writer.Flush()
+		}
+		return resultWithUsage(), nil
+	}
+
+	// handleScanErr logs scanner errors if meaningful.
+	handleScanErr := func(err error) {
+		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
 			logger.L().Warn("openai messages stream: read error",
 				zap.Error(err),
 				zap.String("request_id", requestID),
@@ -410,27 +472,94 @@
 		}
 	}
 
-	// Ensure the Anthropic stream is properly terminated
-	if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
-		for _, evt := range finalEvents {
-			sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
-			if err != nil {
-				continue
-			}
-			fmt.Fprint(c.Writer, sse) //nolint:errcheck
-		}
-		c.Writer.Flush()
+	// ── Determine keepalive interval ──
+	keepaliveInterval := time.Duration(0)
+	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
+		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
 	}
 
-	return &OpenAIForwardResult{
-		RequestID:    requestID,
-		Usage:        usage,
-		Model:        originalModel,
-		BillingModel: mappedModel,
-		Stream:       true,
-		Duration:     time.Since(startTime),
-		FirstTokenMs: firstTokenMs,
-	}, nil
+	// ── No keepalive: fast synchronous path (no goroutine overhead) ──
+	if keepaliveInterval <= 0 {
+		for scanner.Scan() {
+			line := scanner.Text()
+			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
+				continue
+			}
+			if processDataLine(line[6:]) {
+				return resultWithUsage(), nil
+			}
+		}
+		handleScanErr(scanner.Err())
+		return finalizeStream()
+	}
+
+	// ── With keepalive: goroutine + channel + select ──
+	type scanEvent struct {
+		line string
+		err  error
+	}
+	events := make(chan scanEvent, 16)
+	done := make(chan struct{})
+	sendEvent := func(ev scanEvent) bool {
+		select {
+		case events <- ev:
+			return true
+		case <-done:
+			return false
+		}
+	}
+	go func() {
+		defer close(events)
+		for scanner.Scan() {
+			if !sendEvent(scanEvent{line: scanner.Text()}) {
+				return
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			_ = sendEvent(scanEvent{err: err})
+		}
+	}()
+	defer close(done)
+
+	keepaliveTicker := time.NewTicker(keepaliveInterval)
+	defer keepaliveTicker.Stop()
+	lastDataAt := time.Now()
+
+	for {
+		select {
+		case ev, ok := <-events:
+			if !ok {
+				// Upstream closed
				return finalizeStream()
+			}
+			if ev.err != nil {
+				handleScanErr(ev.err)
+				return finalizeStream()
+			}
+			lastDataAt = time.Now()
+			line := ev.line
+			if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
+				continue
+			}
+			if processDataLine(line[6:]) {
+				return resultWithUsage(), nil
+			}
+
+		case <-keepaliveTicker.C:
+			if time.Since(lastDataAt) < keepaliveInterval {
+				continue
+			}
+			// Send Anthropic-format ping event
+			if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
+				// Client disconnected
+				logger.L().Info("openai messages stream: client disconnected during keepalive",
+					zap.String("request_id", requestID),
+				)
+				return resultWithUsage(), nil
+			}
+			c.Writer.Flush()
+		}
+	}
 }
 
 // writeAnthropicError writes an error response in Anthropic Messages API format.
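Note (not part of the patch): a minimal, self-contained sketch of the keepalive pattern the streaming path adopts above, under the same assumptions as the patch — one reader goroutine feeds upstream lines into a channel while the main loop selects between data and a ticker that emits an Anthropic-style ping frame during silence. The names here (upstream, lines, pingEvery) and the demo input are illustrative only, not part of the gateway code.

package main

import (
	"bufio"
	"fmt"
	"strings"
	"time"
)

func main() {
	// Hypothetical upstream producing a single SSE data line, then closing.
	upstream := bufio.NewScanner(strings.NewReader("data: {\"type\":\"demo\"}\n"))

	lines := make(chan string, 16) // reader goroutine -> main loop
	done := make(chan struct{})    // closed when the main loop returns

	go func() {
		defer close(lines)
		for upstream.Scan() {
			select {
			case lines <- upstream.Text():
			case <-done:
				return
			}
		}
	}()
	defer close(done)

	pingEvery := 100 * time.Millisecond // stands in for the configured keepalive interval
	ticker := time.NewTicker(pingEvery)
	defer ticker.Stop()
	lastData := time.Now()

	for {
		select {
		case line, ok := <-lines:
			if !ok {
				return // upstream closed: finalize and return
			}
			lastData = time.Now()
			fmt.Println("forward:", line)
		case <-ticker.C:
			if time.Since(lastData) < pingEvery {
				continue
			}
			// Ping frame during upstream silence, as in the patch.
			fmt.Print("event: ping\ndata: {\"type\":\"ping\"}\n\n")
		}
	}
}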