Files
sub2api/backend/internal/service/openai_gateway_messages.go

538 lines
17 KiB
Go

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsAnthropic accepts an Anthropic Messages request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Anthropic Messages format. This enables Claude Code
// clients to access OpenAI models through the standard /v1/messages endpoint.
func (s *OpenAIGatewayService) ForwardAsAnthropic(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse Anthropic request
var anthropicReq apicompat.AnthropicRequest
if err := json.Unmarshal(body, &anthropicReq); err != nil {
return nil, fmt.Errorf("parse anthropic request: %w", err)
}
originalModel := anthropicReq.Model
clientStream := anthropicReq.Stream // client's original stream preference
// 2. Convert Anthropic → Responses
responsesReq, err := apicompat.AnthropicToResponses(&anthropicReq)
if err != nil {
return nil, fmt.Errorf("convert anthropic to responses: %w", err)
}
// Upstream always uses streaming (upstream may not support sync mode).
// The client's original preference determines the response format.
responsesReq.Stream = true
isStream := true
// 2b. Handle BetaFastMode → service_tier: "priority"
if containsBetaToken(c.GetHeader("anthropic-beta"), claude.BetaFastMode) {
responsesReq.ServiceTier = "priority"
}
// 3. Model mapping
mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel)
responsesReq.Model = mappedModel
logger.L().Debug("openai messages: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("stream", isStream),
)
// 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
if account.Type == AccountTypeOAuth {
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
// OAuth codex transform forces stream=true upstream, so always use
// the streaming response handler regardless of what the client asked.
isStream = true
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, isStream, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
// Override session_id with a deterministic UUID derived from the sticky
// session key (buildUpstreamRequest may have set it to the raw value).
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
return s.handleAnthropicErrorResponse(resp, c, account)
}
// 9. Handle normal response
// Upstream is always streaming; choose response format based on client preference.
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime)
} else {
// Client wants JSON: buffer the streaming response and assemble a JSON reply.
result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
// handleAnthropicErrorResponse reads an upstream error and returns it in
// Anthropic error format.
func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from
// the upstream streaming response, finds the terminal event (response.completed
// / response.incomplete / response.failed), converts the complete response to
// Anthropic Messages JSON format, and writes it to the client.
// This is used when the client requested stream=false but the upstream is always
// streaming.
func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
// Terminal events carry the complete ResponsesResponse with output + usage.
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeAnthropicError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
anthropicResp := apicompat.ResponsesToAnthropic(finalResponse, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, anthropicResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleAnthropicStreamingResponse reads Responses SSE events from upstream,
// converts each to Anthropic SSE events, and writes them to the client.
// When StreamKeepaliveInterval is configured, it uses a goroutine + channel
// pattern to send Anthropic ping events during periods of upstream silence,
// preventing proxy/client timeout disconnections.
func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToAnthropicState()
state.Model = originalModel
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
// resultWithUsage builds the final result snapshot.
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
// processDataLine handles a single "data: ..." SSE line from upstream.
// Returns (clientDisconnected bool).
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai messages stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
// Convert to Anthropic events
events := apicompat.ResponsesEventToAnthropicEvents(&event, state)
for _, evt := range events {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
logger.L().Warn("openai messages stream: failed to marshal event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai messages stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
if len(events) > 0 {
c.Writer.Flush()
}
return false
}
// finalizeStream sends any remaining Anthropic events and returns the result.
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalEvents := apicompat.FinalizeResponsesAnthropicStream(state); len(finalEvents) > 0 {
for _, evt := range finalEvents {
sse, err := apicompat.ResponsesAnthropicEventToSSE(evt)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
c.Writer.Flush()
}
return resultWithUsage(), nil
}
// handleScanErr logs scanner errors if meaningful.
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai messages stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// ── With keepalive: goroutine + channel + select ──
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
// Upstream closed
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send Anthropic-format ping event
if _, err := fmt.Fprint(c.Writer, "event: ping\ndata: {\"type\":\"ping\"}\n\n"); err != nil {
// Client disconnected
logger.L().Info("openai messages stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeAnthropicError writes an error response in Anthropic Messages API format.
func writeAnthropicError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}