refactor: 重构 Chat Completions 端点,采用类型安全的 Responses API 转换

将 /v1/chat/completions 端点从 ResponseWriter 劫持模式重构为独立的
类型安全转换路径,与 Anthropic Messages 端点架构对齐:

- 在 apicompat 包新增 Chat Completions 完整类型定义和双向转换器
- 新增 ForwardAsChatCompletions service 方法,走 Responses API 上游
- Handler 改为独立的账号选择/failover 循环,不再劫持 Responses handler
- 提取 handleCompatErrorResponse 为 Chat Completions 和 Messages 共用
- 删除旧的 forwardChatCompletions 直传路径及相关死代码
This commit is contained in:
shaw
2026-03-11 22:10:22 +08:00
parent 8dd38f4775
commit 9d81467937
11 changed files with 2420 additions and 1717 deletions

View File

@@ -1,23 +1,53 @@
package handler
import (
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"io"
"context"
"errors"
"net/http"
"strings"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
)
// ChatCompletions handles OpenAI Chat Completions API compatibility.
// ChatCompletions handles OpenAI Chat Completions API requests.
// POST /v1/chat/completions
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
streamStarted := false
defer h.recoverResponsesPanic(c, &streamStarted)
requestStart := time.Now()
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
reqLog := requestLogger(
c,
"handler.openai_gateway.chat_completions",
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
if !h.ensureResponsesDependencies(c, reqLog) {
return
}
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -31,516 +61,230 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
// Preserve original chat-completions request for upstream passthrough when needed.
c.Set(service.OpenAIChatCompletionsBodyKey, body)
var chatReq map[string]any
if err := json.Unmarshal(body, &chatReq); err != nil {
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
includeUsage := false
if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
if v, ok := streamOptions["include_usage"].(bool); ok {
includeUsage = v
}
modelResult := gjson.GetBytes(body, "model")
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
converted, err := service.ConvertChatCompletionsToResponses(chatReq)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
subscription, _ := middleware2.GetSubscriptionFromContext(c)
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
convertedBody, err := json.Marshal(converted)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
stream, _ := converted["stream"].(bool)
model, _ := converted["model"].(string)
originalWriter := c.Writer
writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
c.Writer = writer
c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
c.Request.ContentLength = int64(len(convertedBody))
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
h.Responses(c)
writer.Finalize()
c.Writer = originalWriter
}
type chatCompletionsResponseWriter struct {
gin.ResponseWriter
stream bool
includeUsage bool
buffer bytes.Buffer
streamBuf bytes.Buffer
state *chatCompletionStreamState
corrector *service.CodexToolCorrector
finalized bool
passthrough bool
}
type chatCompletionStreamState struct {
id string
model string
created int64
sentRole bool
sawToolCall bool
sawText bool
toolCallIndex map[string]int
usage map[string]any
}
func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter {
return &chatCompletionsResponseWriter{
ResponseWriter: w,
stream: stream,
includeUsage: includeUsage,
state: &chatCompletionStreamState{
model: strings.TrimSpace(model),
toolCallIndex: make(map[string]int),
},
corrector: service.NewCodexToolCorrector(),
}
}
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
if w.passthrough {
return w.ResponseWriter.Write(data)
}
if w.stream {
n, err := w.streamBuf.Write(data)
if err != nil {
return n, err
}
w.flushStreamBuffer()
return n, nil
}
if w.finalized {
return len(data), nil
}
return w.buffer.Write(data)
}
func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) {
return w.Write([]byte(s))
}
func (w *chatCompletionsResponseWriter) Finalize() {
if w.finalized {
return
}
w.finalized = true
if w.passthrough {
return
}
if w.stream {
return
}
body := w.buffer.Bytes()
if len(body) == 0 {
return
}
w.ResponseWriter.Header().Del("Content-Length")
converted, err := service.ConvertResponsesToChatCompletion(body)
if err != nil {
_, _ = w.ResponseWriter.Write(body)
return
}
corrected := converted
if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
corrected = []byte(correctedStr)
}
_, _ = w.ResponseWriter.Write(corrected)
}
func (w *chatCompletionsResponseWriter) SetPassthrough() {
w.passthrough = true
}
func (w *chatCompletionsResponseWriter) Status() int {
if w.ResponseWriter == nil {
return 0
}
return w.ResponseWriter.Status()
}
func (w *chatCompletionsResponseWriter) Written() bool {
if w.ResponseWriter == nil {
return false
}
return w.ResponseWriter.Written()
}
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
for {
buf := w.streamBuf.Bytes()
idx := bytes.IndexByte(buf, '\n')
if idx == -1 {
return
}
lineBytes := w.streamBuf.Next(idx + 1)
line := strings.TrimRight(string(lineBytes), "\r\n")
w.handleStreamLine(line)
}
}
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
if line == "" {
return
}
if strings.HasPrefix(line, ":") {
_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
return
}
if !strings.HasPrefix(line, "data:") {
return
}
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
for _, chunk := range w.convertResponseDataToChatChunks(data) {
if chunk == "" {
continue
}
if chunk == "[DONE]" {
_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
continue
}
_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
}
}
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
if data == "" {
return nil
}
if data == "[DONE]" {
return []string{"[DONE]"}
}
var payload map[string]any
if err := json.Unmarshal([]byte(data), &payload); err != nil {
return []string{data}
}
if _, ok := payload["error"]; ok {
return []string{data}
}
eventType := strings.TrimSpace(getString(payload["type"]))
if eventType == "" {
return []string{data}
}
w.state.applyMetadata(payload)
switch eventType {
case "response.created":
return nil
case "response.output_text.delta":
delta := getString(payload["delta"])
if delta == "" {
return nil
}
w.state.sawText = true
return []string{w.buildTextDeltaChunk(delta)}
case "response.output_text.done":
if w.state.sawText {
return nil
}
text := getString(payload["text"])
if text == "" {
return nil
}
w.state.sawText = true
return []string{w.buildTextDeltaChunk(text)}
case "response.output_item.added", "response.output_item.delta":
if item, ok := payload["item"].(map[string]any); ok {
if callID, name, args, ok := extractToolCallFromItem(item); ok {
w.state.sawToolCall = true
return []string{w.buildToolCallChunk(callID, name, args)}
c.Set("openai_chat_completions_fallback_model", "")
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_chat_completions.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
}
}
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
} else {
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
}
return
}
}
case "response.completed", "response.done":
if responseObj, ok := payload["response"].(map[string]any); ok {
w.state.applyResponseUsage(responseObj)
if selection == nil || selection.Account == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
return []string{w.buildFinalChunk()}
}
account := selection.Account
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
_ = scheduleDecision
setOpsSelectedAccount(c, account.ID, account.Platform)
if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
callID := strings.TrimSpace(getString(payload["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(payload["tool_call_id"]))
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
if !acquired {
return
}
if callID == "" {
callID = strings.TrimSpace(getString(payload["id"]))
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := ""
if apiKey.Group != nil {
defaultMappedModel = apiKey.Group.DefaultMappedModel
}
args := getString(payload["delta"])
name := strings.TrimSpace(getString(payload["name"]))
if callID != "" && (args != "" || name != "") {
w.state.sawToolCall = true
return []string{w.buildToolCallChunk(callID, name, args)}
if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" {
defaultMappedModel = fallbackModel
}
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
return nil
}
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
w.state.ensureDefaults()
payload := map[string]any{
"content": delta,
}
if !w.state.sentRole {
payload["role"] = "assistant"
w.state.sentRole = true
}
return w.buildChunk(payload, nil, nil)
}
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
w.state.ensureDefaults()
index := w.state.toolCallIndexFor(callID)
function := map[string]any{}
if name != "" {
function["name"] = name
}
if args != "" {
function["arguments"] = args
}
toolCall := map[string]any{
"index": index,
"id": callID,
"type": "function",
"function": function,
}
delta := map[string]any{
"tool_calls": []any{toolCall},
}
if !w.state.sentRole {
delta["role"] = "assistant"
w.state.sentRole = true
}
return w.buildChunk(delta, nil, nil)
}
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
w.state.ensureDefaults()
finishReason := "stop"
if w.state.sawToolCall {
finishReason = "tool_calls"
}
usage := map[string]any(nil)
if w.includeUsage && w.state.usage != nil {
usage = w.state.usage
}
return w.buildChunk(map[string]any{}, finishReason, usage)
}
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
w.state.ensureDefaults()
chunk := map[string]any{
"id": w.state.id,
"object": "chat.completion.chunk",
"created": w.state.created,
"model": w.state.model,
"choices": []any{
map[string]any{
"index": 0,
"delta": delta,
"finish_reason": finishReason,
},
},
}
if usage != nil {
chunk["usage"] = usage
}
data, _ := json.Marshal(chunk)
if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok {
return corrected
}
return string(data)
}
func (s *chatCompletionStreamState) ensureDefaults() {
if s.id == "" {
s.id = "chatcmpl-" + randomHexUnsafe(12)
}
if s.model == "" {
s.model = "unknown"
}
if s.created == 0 {
s.created = time.Now().Unix()
}
}
func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int {
if idx, ok := s.toolCallIndex[callID]; ok {
return idx
}
idx := len(s.toolCallIndex)
s.toolCallIndex[callID] = idx
return idx
}
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
if responseObj, ok := payload["response"].(map[string]any); ok {
s.applyResponseMetadata(responseObj)
}
if s.id == "" {
if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
s.id = id
} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
s.id = id
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}
if s.model == "" {
if model := strings.TrimSpace(getString(payload["model"])); model != "" {
s.model = model
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
}
if s.created == 0 {
if created := getInt64(payload["created_at"]); created != 0 {
s.created = created
} else if created := getInt64(payload["created"]); created != 0 {
s.created = created
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
retryLimit := account.GetPoolModeRetryCount()
if sameAccountRetryCount[account.ID] < retryLimit {
sameAccountRetryCount[account.ID]++
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("retry_limit", retryLimit),
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
)
select {
case <-c.Request.Context().Done():
return
case <-time.After(sameAccountRetryDelay):
}
continue
}
}
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
switchCount++
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
)
continue
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
reqLog.Warn("openai_chat_completions.forward_failed",
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
)
return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
} else {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
}
}
}
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
if s.id == "" {
if id := strings.TrimSpace(getString(responseObj["id"])); id != "" {
s.id = id
}
}
if s.model == "" {
if model := strings.TrimSpace(getString(responseObj["model"])); model != "" {
s.model = model
}
}
if s.created == 0 {
if created := getInt64(responseObj["created_at"]); created != 0 {
s.created = created
}
}
}
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
usage, ok := responseObj["usage"].(map[string]any)
if !ok {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),
zap.Int64("user_id", subject.UserID),
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
zap.String("model", reqModel),
zap.Int64("account_id", account.ID),
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
}
})
reqLog.Debug("openai_chat_completions.request_completed",
zap.Int64("account_id", account.ID),
zap.Int("switch_count", switchCount),
)
return
}
promptTokens := int(getNumber(usage["input_tokens"]))
completionTokens := int(getNumber(usage["output_tokens"]))
if promptTokens == 0 && completionTokens == 0 {
return
}
s.usage = map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
}
}
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
itemType := strings.TrimSpace(getString(item["type"]))
if itemType != "tool_call" && itemType != "function_call" {
return "", "", "", false
}
callID := strings.TrimSpace(getString(item["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(item["id"]))
}
name := strings.TrimSpace(getString(item["name"]))
args := getString(item["arguments"])
if fn, ok := item["function"].(map[string]any); ok {
if name == "" {
name = strings.TrimSpace(getString(fn["name"]))
}
if args == "" {
args = getString(fn["arguments"])
}
}
if callID == "" && name == "" && args == "" {
return "", "", "", false
}
if callID == "" {
callID = "call_" + randomHexUnsafe(6)
}
return callID, name, args, true
}
func getString(value any) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case json.Number:
return v.String()
default:
return ""
}
}
func getNumber(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
f, _ := v.Float64()
return f
default:
return 0
}
}
func getInt64(value any) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case float64:
return int64(v)
case json.Number:
i, _ := v.Int64()
return i
default:
return 0
}
}
func randomHexUnsafe(byteLength int) string {
if byteLength <= 0 {
byteLength = 8
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "000000"
}
return hex.EncodeToString(buf)
}

View File

@@ -0,0 +1,733 @@
package apicompat
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// ChatCompletionsToResponses tests
// ---------------------------------------------------------------------------
func TestChatCompletionsToResponses_BasicText(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hello"`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.Equal(t, "gpt-4o", resp.Model)
assert.True(t, resp.Stream) // always forced true
assert.False(t, *resp.Store)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
assert.Equal(t, "user", items[0].Role)
}
func TestChatCompletionsToResponses_SystemMessage(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "system", Content: json.RawMessage(`"You are helpful."`)},
{Role: "user", Content: json.RawMessage(`"Hi"`)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
assert.Equal(t, "system", items[0].Role)
assert.Equal(t, "user", items[1].Role)
}
func TestChatCompletionsToResponses_ToolCalls(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Call the function"`)},
{
Role: "assistant",
ToolCalls: []ChatToolCall{
{
ID: "call_1",
Type: "function",
Function: ChatFunctionCall{
Name: "ping",
Arguments: `{"host":"example.com"}`,
},
},
},
},
{
Role: "tool",
ToolCallID: "call_1",
Content: json.RawMessage(`"pong"`),
},
},
Tools: []ChatTool{
{
Type: "function",
Function: &ChatFunction{
Name: "ping",
Description: "Ping a host",
Parameters: json.RawMessage(`{"type":"object"}`),
},
},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
// user + function_call + function_call_output = 3
// (assistant message with empty content + tool_calls → only function_call items emitted)
require.Len(t, items, 3)
// Check function_call item
assert.Equal(t, "function_call", items[1].Type)
assert.Equal(t, "call_1", items[1].CallID)
assert.Equal(t, "ping", items[1].Name)
// Check function_call_output item
assert.Equal(t, "function_call_output", items[2].Type)
assert.Equal(t, "call_1", items[2].CallID)
assert.Equal(t, "pong", items[2].Output)
// Check tools
require.Len(t, resp.Tools, 1)
assert.Equal(t, "function", resp.Tools[0].Type)
assert.Equal(t, "ping", resp.Tools[0].Name)
}
func TestChatCompletionsToResponses_MaxTokens(t *testing.T) {
t.Run("max_tokens", func(t *testing.T) {
maxTokens := 100
req := &ChatCompletionsRequest{
Model: "gpt-4o",
MaxTokens: &maxTokens,
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.MaxOutputTokens)
// Below minMaxOutputTokens (128), should be clamped
assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens)
})
t.Run("max_completion_tokens_preferred", func(t *testing.T) {
maxTokens := 100
maxCompletion := 500
req := &ChatCompletionsRequest{
Model: "gpt-4o",
MaxTokens: &maxTokens,
MaxCompletionTokens: &maxCompletion,
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.MaxOutputTokens)
assert.Equal(t, 500, *resp.MaxOutputTokens)
})
}
func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
ReasoningEffort: "high",
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.Reasoning)
assert.Equal(t, "high", resp.Reasoning.Effort)
assert.Equal(t, "auto", resp.Reasoning.Summary)
}
func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]`
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(content)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 2)
assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "Describe this", parts[0].Text)
assert.Equal(t, "input_image", parts[1].Type)
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
}
func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
},
Functions: []ChatFunction{
{
Name: "get_weather",
Description: "Get weather",
Parameters: json.RawMessage(`{"type":"object"}`),
},
},
FunctionCall: json.RawMessage(`{"name":"get_weather"}`),
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.Len(t, resp.Tools, 1)
assert.Equal(t, "function", resp.Tools[0].Type)
assert.Equal(t, "get_weather", resp.Tools[0].Name)
// tool_choice should be converted
require.NotNil(t, resp.ToolChoice)
var tc map[string]any
require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc))
assert.Equal(t, "function", tc["type"])
}
func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
ServiceTier: "flex",
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.Equal(t, "flex", resp.ServiceTier)
}
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Do something"`)},
{
Role: "assistant",
Content: json.RawMessage(`"Let me call a function."`),
ToolCalls: []ChatToolCall{
{
ID: "call_abc",
Type: "function",
Function: ChatFunctionCall{
Name: "do_thing",
Arguments: `{}`,
},
},
},
},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
// user + assistant message (with text) + function_call
require.Len(t, items, 3)
assert.Equal(t, "user", items[0].Role)
assert.Equal(t, "assistant", items[1].Role)
assert.Equal(t, "function_call", items[2].Type)
}
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions tests
// ---------------------------------------------------------------------------
func TestResponsesToChatCompletions_BasicText(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_123",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "Hello, world!"},
},
},
},
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
assert.Equal(t, "chat.completion", chat.Object)
assert.Equal(t, "gpt-4o", chat.Model)
require.Len(t, chat.Choices, 1)
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
assert.Equal(t, "Hello, world!", content)
require.NotNil(t, chat.Usage)
assert.Equal(t, 10, chat.Usage.PromptTokens)
assert.Equal(t, 5, chat.Usage.CompletionTokens)
assert.Equal(t, 15, chat.Usage.TotalTokens)
}
func TestResponsesToChatCompletions_ToolCalls(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "function_call",
CallID: "call_xyz",
Name: "get_weather",
Arguments: `{"city":"NYC"}`,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason)
msg := chat.Choices[0].Message
require.Len(t, msg.ToolCalls, 1)
assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID)
assert.Equal(t, "function", msg.ToolCalls[0].Type)
assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name)
assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments)
}
func TestResponsesToChatCompletions_Reasoning(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_789",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "reasoning",
Summary: []ResponsesSummary{
{Type: "summary_text", Text: "I thought about it."},
},
},
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "The answer is 42."},
},
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
// Reasoning summary is prepended to text
assert.Equal(t, "I thought about it.The answer is 42.", content)
}
func TestResponsesToChatCompletions_Incomplete(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_inc",
Status: "incomplete",
IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"},
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{
{Type: "output_text", Text: "partial..."},
},
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
assert.Equal(t, "length", chat.Choices[0].FinishReason)
}
func TestResponsesToChatCompletions_CachedTokens(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_cache",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}},
},
},
Usage: &ResponsesUsage{
InputTokens: 100,
OutputTokens: 10,
TotalTokens: 110,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 80,
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.NotNil(t, chat.Usage)
require.NotNil(t, chat.Usage.PromptTokensDetails)
assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesToChatCompletions_WebSearch(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_ws",
Status: "completed",
Output: []ResponsesOutput{
{
Type: "web_search_call",
Action: &WebSearchAction{Type: "search", Query: "test"},
},
{
Type: "message",
Content: []ResponsesContentPart{{Type: "output_text", Text: "search results"}},
},
},
}
chat := ResponsesToChatCompletions(resp, "gpt-4o")
require.Len(t, chat.Choices, 1)
assert.Equal(t, "stop", chat.Choices[0].FinishReason)
var content string
require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content))
assert.Equal(t, "search results", content)
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesEventToChatChunks tests
// ---------------------------------------------------------------------------
func TestResponsesEventToChatChunks_TextDelta(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
// response.created → role chunk
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{
ID: "resp_stream",
},
}, state)
require.Len(t, chunks, 1)
assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role)
assert.True(t, state.SentRole)
// response.output_text.delta → content chunk
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: "Hello",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content)
}
func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
// response.output_item.added (function_call) — output_index=1 (e.g. after a message item at 0)
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 1,
Item: &ResponsesOutput{
Type: "function_call",
CallID: "call_1",
Name: "get_weather",
},
}, state)
require.Len(t, chunks, 1)
require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1)
tc := chunks[0].Choices[0].Delta.ToolCalls[0]
assert.Equal(t, "call_1", tc.ID)
assert.Equal(t, "get_weather", tc.Function.Name)
require.NotNil(t, tc.Index)
assert.Equal(t, 0, *tc.Index)
// response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 1, // matches the output_index from output_item.added above
Delta: `{"city":`,
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call")
assert.Equal(t, `{"city":`, tc.Function.Arguments)
// Add a second function call at output_index=2
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_item.added",
OutputIndex: 2,
Item: &ResponsesOutput{
Type: "function_call",
CallID: "call_2",
Name: "get_time",
},
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 1, *tc.Index, "second tool call should get index 1")
// Argument delta for second tool call
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 2,
Delta: `{"tz":"UTC"}`,
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1")
// Argument delta for first tool call (interleaved)
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.function_call_arguments.delta",
OutputIndex: 1,
Delta: `"Tokyo"}`,
}, state)
require.Len(t, chunks, 1)
tc = chunks[0].Choices[0].Delta.ToolCalls[0]
require.NotNil(t, tc.Index)
assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0")
}
func TestResponsesEventToChatChunks_Completed(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 50,
OutputTokens: 20,
TotalTokens: 70,
InputTokensDetails: &ResponsesInputTokensDetails{
CachedTokens: 30,
},
},
},
}, state)
// finish chunk + usage chunk
require.Len(t, chunks, 2)
// First chunk: finish_reason
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
// Second chunk: usage
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 50, chunks[1].Usage.PromptTokens)
assert.Equal(t, 20, chunks[1].Usage.CompletionTokens)
assert.Equal(t, 70, chunks[1].Usage.TotalTokens)
require.NotNil(t, chunks[1].Usage.PromptTokensDetails)
assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens)
}
func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SawToolCall = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
},
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason)
}
func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.SentRole = true
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
Delta: "Thinking...",
}, state)
require.Len(t, chunks, 1)
require.NotNil(t, chunks[0].Choices[0].Delta.Content)
assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content)
}
func TestFinalizeResponsesChatStream(t *testing.T) {
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
state.Usage = &ChatUsage{
PromptTokens: 100,
CompletionTokens: 50,
TotalTokens: 150,
}
chunks := FinalizeResponsesChatStream(state)
require.Len(t, chunks, 2)
// Finish chunk
require.NotNil(t, chunks[0].Choices[0].FinishReason)
assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason)
// Usage chunk
require.NotNil(t, chunks[1].Usage)
assert.Equal(t, 100, chunks[1].Usage.PromptTokens)
// Idempotent: second call returns nil
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) {
// If response.completed already emitted the finish chunk, FinalizeResponsesChatStream
// must be a no-op (prevents double finish_reason being sent to the client).
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
// Simulate response.completed
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
},
}, state)
require.NotEmpty(t, chunks) // finish + usage chunks
// Now FinalizeResponsesChatStream should return nil — already finalized.
assert.Nil(t, FinalizeResponsesChatStream(state))
}
func TestChatChunkToSSE(t *testing.T) {
chunk := ChatCompletionsChunk{
ID: "chatcmpl-test",
Object: "chat.completion.chunk",
Created: 1700000000,
Model: "gpt-4o",
Choices: []ChatChunkChoice{
{
Index: 0,
Delta: ChatDelta{Role: "assistant"},
FinishReason: nil,
},
},
}
sse, err := ChatChunkToSSE(chunk)
require.NoError(t, err)
assert.Contains(t, sse, "data: ")
assert.Contains(t, sse, "chatcmpl-test")
assert.Contains(t, sse, "assistant")
assert.True(t, len(sse) > 10)
}
// ---------------------------------------------------------------------------
// Stream round-trip test
// ---------------------------------------------------------------------------
func TestChatCompletionsStreamRoundTrip(t *testing.T) {
// Simulate: client sends chat completions request, upstream returns Responses SSE events.
// Verify that the streaming state machine produces correct chat completions chunks.
state := NewResponsesEventToChatState()
state.Model = "gpt-4o"
state.IncludeUsage = true
var allChunks []ChatCompletionsChunk
// 1. response.created
chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.created",
Response: &ResponsesResponse{ID: "resp_rt"},
}, state)
allChunks = append(allChunks, chunks...)
// 2. text deltas
for _, text := range []string{"Hello", ", ", "world", "!"} {
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.output_text.delta",
Delta: text,
}, state)
allChunks = append(allChunks, chunks...)
}
// 3. response.completed
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{
Type: "response.completed",
Response: &ResponsesResponse{
Status: "completed",
Usage: &ResponsesUsage{
InputTokens: 10,
OutputTokens: 4,
TotalTokens: 14,
},
},
}, state)
allChunks = append(allChunks, chunks...)
// Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7
require.Len(t, allChunks, 7)
// First chunk has role
assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role)
// Text chunks
var fullText string
for i := 1; i <= 4; i++ {
require.NotNil(t, allChunks[i].Choices[0].Delta.Content)
fullText += *allChunks[i].Choices[0].Delta.Content
}
assert.Equal(t, "Hello, world!", fullText)
// Finish chunk
require.NotNil(t, allChunks[5].Choices[0].FinishReason)
assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason)
// Usage chunk
require.NotNil(t, allChunks[6].Usage)
assert.Equal(t, 10, allChunks[6].Usage.PromptTokens)
assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens)
// All chunks share the same ID
for _, c := range allChunks {
assert.Equal(t, "resp_rt", c.ID)
}
}

View File

@@ -0,0 +1,312 @@
package apicompat
import (
"encoding/json"
"fmt"
)
// ChatCompletionsToResponses converts a Chat Completions request into a
// Responses API request. The upstream always streams, so Stream is forced to
// true. store is always false and reasoning.encrypted_content is always
// included so that the response translator has full context.
func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) {
input, err := convertChatMessagesToResponsesInput(req.Messages)
if err != nil {
return nil, err
}
inputJSON, err := json.Marshal(input)
if err != nil {
return nil, err
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
storeFalse := false
out.Store = &storeFalse
// max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens
maxTokens := 0
if req.MaxTokens != nil {
maxTokens = *req.MaxTokens
}
if req.MaxCompletionTokens != nil {
maxTokens = *req.MaxCompletionTokens
}
if maxTokens > 0 {
v := maxTokens
if v < minMaxOutputTokens {
v = minMaxOutputTokens
}
out.MaxOutputTokens = &v
}
// reasoning_effort → reasoning.effort + reasoning.summary="auto"
if req.ReasoningEffort != "" {
out.Reasoning = &ResponsesReasoning{
Effort: req.ReasoningEffort,
Summary: "auto",
}
}
// tools[] and legacy functions[] → ResponsesTool[]
if len(req.Tools) > 0 || len(req.Functions) > 0 {
out.Tools = convertChatToolsToResponses(req.Tools, req.Functions)
}
// tool_choice: already compatible format — pass through directly.
// Legacy function_call needs mapping.
if len(req.ToolChoice) > 0 {
out.ToolChoice = req.ToolChoice
} else if len(req.FunctionCall) > 0 {
tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall)
if err != nil {
return nil, fmt.Errorf("convert function_call: %w", err)
}
out.ToolChoice = tc
}
return out, nil
}
// convertChatMessagesToResponsesInput converts the Chat Completions messages
// array into a Responses API input items array.
func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) {
var out []ResponsesInputItem
for _, m := range msgs {
items, err := chatMessageToResponsesItems(m)
if err != nil {
return nil, err
}
out = append(out, items...)
}
return out, nil
}
// chatMessageToResponsesItems converts a single ChatMessage into one or more
// ResponsesInputItem values.
func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) {
switch m.Role {
case "system":
return chatSystemToResponses(m)
case "user":
return chatUserToResponses(m)
case "assistant":
return chatAssistantToResponses(m)
case "tool":
return chatToolToResponses(m)
case "function":
return chatFunctionToResponses(m)
default:
return chatUserToResponses(m)
}
}
// chatSystemToResponses converts a system message.
func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
text, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
content, err := json.Marshal(text)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "system", Content: content}}, nil
}
// chatUserToResponses converts a user message, handling both plain strings and
// multi-modal content arrays.
func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// Try plain string first.
var s string
if err := json.Unmarshal(m.Content, &s); err == nil {
content, _ := json.Marshal(s)
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
var parts []ChatContentPart
if err := json.Unmarshal(m.Content, &parts); err != nil {
return nil, fmt.Errorf("parse user content: %w", err)
}
var responseParts []ResponsesContentPart
for _, p := range parts {
switch p.Type {
case "text":
if p.Text != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_text",
Text: p.Text,
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
})
}
}
}
content, err := json.Marshal(responseParts)
if err != nil {
return nil, err
}
return []ResponsesInputItem{{Role: "user", Content: content}}, nil
}
// chatAssistantToResponses converts an assistant message. If there is both
// text content and tool_calls, the text is emitted as an assistant message
// first, then each tool_call becomes a function_call item. If the content is
// empty/nil and there are tool_calls, only function_call items are emitted.
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
var items []ResponsesInputItem
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
var s string
if err := json.Unmarshal(m.Content, &s); err == nil && s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
}
}
// Emit one function_call item per tool_call.
for _, tc := range m.ToolCalls {
args := tc.Function.Arguments
if args == "" {
args = "{}"
}
items = append(items, ResponsesInputItem{
Type: "function_call",
CallID: tc.ID,
Name: tc.Function.Name,
Arguments: args,
ID: tc.ID,
})
}
return items, nil
}
// chatToolToResponses converts a tool result message (role=tool) into a
// function_call_output item.
func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.ToolCallID,
Output: output,
}}, nil
}
// chatFunctionToResponses converts a legacy function result message
// (role=function) into a function_call_output item. The Name field is used as
// call_id since legacy function calls do not carry a separate call_id.
func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
output, err := parseChatContent(m.Content)
if err != nil {
return nil, err
}
if output == "" {
output = "(empty)"
}
return []ResponsesInputItem{{
Type: "function_call_output",
CallID: m.Name,
Output: output,
}}, nil
}
// parseChatContent returns the string value of a ChatMessage Content field.
// Content must be a JSON string. Returns "" if content is null or empty.
func parseChatContent(raw json.RawMessage) (string, error) {
if len(raw) == 0 {
return "", nil
}
var s string
if err := json.Unmarshal(raw, &s); err != nil {
return "", fmt.Errorf("parse content as string: %w", err)
}
return s, nil
}
// convertChatToolsToResponses maps Chat Completions tool definitions and legacy
// function definitions to Responses API tool definitions.
func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool {
var out []ResponsesTool
for _, t := range tools {
if t.Type != "function" || t.Function == nil {
continue
}
rt := ResponsesTool{
Type: "function",
Name: t.Function.Name,
Description: t.Function.Description,
Parameters: t.Function.Parameters,
Strict: t.Function.Strict,
}
out = append(out, rt)
}
// Legacy functions[] are treated as function-type tools.
for _, f := range functions {
rt := ResponsesTool{
Type: "function",
Name: f.Name,
Description: f.Description,
Parameters: f.Parameters,
Strict: f.Strict,
}
out = append(out, rt)
}
return out
}
// convertChatFunctionCallToToolChoice maps the legacy function_call field to a
// Responses API tool_choice value.
//
// "auto" → "auto"
// "none" → "none"
// {"name":"X"} → {"type":"function","function":{"name":"X"}}
func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) {
// Try string first ("auto", "none", etc.) — pass through as-is.
var s string
if err := json.Unmarshal(raw, &s); err == nil {
return json.Marshal(s)
}
// Object form: {"name":"X"}
var obj struct {
Name string `json:"name"`
}
if err := json.Unmarshal(raw, &obj); err != nil {
return nil, err
}
return json.Marshal(map[string]any{
"type": "function",
"function": map[string]string{"name": obj.Name},
})
}

View File

@@ -0,0 +1,368 @@
package apicompat
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
// ---------------------------------------------------------------------------
// Non-streaming: ResponsesResponse → ChatCompletionsResponse
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions converts a Responses API response into a Chat
// Completions response. Text output items are concatenated into
// choices[0].message.content; function_call items become tool_calls.
func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse {
id := resp.ID
if id == "" {
id = generateChatCmplID()
}
out := &ChatCompletionsResponse{
ID: id,
Object: "chat.completion",
Created: time.Now().Unix(),
Model: model,
}
var contentText string
var toolCalls []ChatToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, part := range item.Content {
if part.Type == "output_text" && part.Text != "" {
contentText += part.Text
}
}
case "function_call":
toolCalls = append(toolCalls, ChatToolCall{
ID: item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: item.Name,
Arguments: item.Arguments,
},
})
case "reasoning":
for _, s := range item.Summary {
if s.Type == "summary_text" && s.Text != "" {
contentText += s.Text
}
}
case "web_search_call":
// silently consumed — results already incorporated into text output
}
}
msg := ChatMessage{Role: "assistant"}
if len(toolCalls) > 0 {
msg.ToolCalls = toolCalls
}
if contentText != "" {
raw, _ := json.Marshal(contentText)
msg.Content = raw
}
finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls)
out.Choices = []ChatChoice{{
Index: 0,
Message: msg,
FinishReason: finishReason,
}}
if resp.Usage != nil {
usage := &ChatUsage{
PromptTokens: resp.Usage.InputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}
if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: resp.Usage.InputTokensDetails.CachedTokens,
}
}
out.Usage = usage
}
return out
}
func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string {
switch status {
case "incomplete":
if details != nil && details.Reason == "max_output_tokens" {
return "length"
}
return "stop"
case "completed":
if len(toolCalls) > 0 {
return "tool_calls"
}
return "stop"
default:
return "stop"
}
}
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter)
// ---------------------------------------------------------------------------
// ResponsesEventToChatState tracks state for converting a sequence of Responses
// SSE events into Chat Completions SSE chunks.
type ResponsesEventToChatState struct {
ID string
Model string
Created int64
SentRole bool
SawToolCall bool
SawText bool
Finalized bool // true after finish chunk has been emitted
NextToolCallIndex int // next sequential tool_call index to assign
OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index
IncludeUsage bool
Usage *ChatUsage
}
// NewResponsesEventToChatState returns an initialised stream state.
func NewResponsesEventToChatState() *ResponsesEventToChatState {
return &ResponsesEventToChatState{
ID: generateChatCmplID(),
Created: time.Now().Unix(),
OutputIndexToToolIndex: make(map[int]int),
}
}
// ResponsesEventToChatChunks converts a single Responses SSE event into zero
// or more Chat Completions chunks, updating state as it goes.
func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
switch evt.Type {
case "response.created":
return resToChatHandleCreated(evt, state)
case "response.output_text.delta":
return resToChatHandleTextDelta(evt, state)
case "response.output_item.added":
return resToChatHandleOutputItemAdded(evt, state)
case "response.function_call_arguments.delta":
return resToChatHandleFuncArgsDelta(evt, state)
case "response.reasoning_summary_text.delta":
return resToChatHandleReasoningDelta(evt, state)
case "response.completed", "response.incomplete", "response.failed":
return resToChatHandleCompleted(evt, state)
default:
return nil
}
}
// FinalizeResponsesChatStream emits a final chunk with finish_reason if the
// stream ended without a proper completion event (e.g. upstream disconnect).
// It is idempotent: if a completion event already emitted the finish chunk,
// this returns nil.
func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk {
if state.Finalized {
return nil
}
state.Finalized = true
finishReason := "stop"
if state.SawToolCall {
finishReason = "tool_calls"
}
chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)}
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line.
func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) {
data, err := json.Marshal(chunk)
if err != nil {
return "", err
}
return fmt.Sprintf("data: %s\n\n", data), nil
}
// --- internal handlers ---
func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Response != nil {
if evt.Response.ID != "" {
state.ID = evt.Response.ID
}
if state.Model == "" && evt.Response.Model != "" {
state.Model = evt.Response.Model
}
}
// Emit the role chunk.
if state.SentRole {
return nil
}
state.SentRole = true
role := "assistant"
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})}
}
func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
state.SawText = true
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Item == nil || evt.Item.Type != "function_call" {
return nil
}
state.SawToolCall = true
idx := state.NextToolCallIndex
state.OutputIndexToToolIndex[evt.OutputIndex] = idx
state.NextToolCallIndex++
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
ID: evt.Item.CallID,
Type: "function",
Function: ChatFunctionCall{
Name: evt.Item.Name,
},
}},
})}
}
func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex]
if !ok {
return nil
}
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{
ToolCalls: []ChatToolCall{{
Index: &idx,
Function: ChatFunctionCall{
Arguments: evt.Delta,
},
}},
})}
}
func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
if evt.Delta == "" {
return nil
}
content := evt.Delta
return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})}
}
func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk {
state.Finalized = true
finishReason := "stop"
if evt.Response != nil {
if evt.Response.Usage != nil {
u := evt.Response.Usage
usage := &ChatUsage{
PromptTokens: u.InputTokens,
CompletionTokens: u.OutputTokens,
TotalTokens: u.InputTokens + u.OutputTokens,
}
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails = &ChatTokenDetails{
CachedTokens: u.InputTokensDetails.CachedTokens,
}
}
state.Usage = usage
}
switch evt.Response.Status {
case "incomplete":
if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" {
finishReason = "length"
}
case "completed":
if state.SawToolCall {
finishReason = "tool_calls"
}
}
} else if state.SawToolCall {
finishReason = "tool_calls"
}
var chunks []ChatCompletionsChunk
chunks = append(chunks, makeChatFinishChunk(state, finishReason))
if state.IncludeUsage && state.Usage != nil {
chunks = append(chunks, ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{},
Usage: state.Usage,
})
}
return chunks
}
func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk {
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: delta,
FinishReason: nil,
}},
}
}
func makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk {
empty := ""
return ChatCompletionsChunk{
ID: state.ID,
Object: "chat.completion.chunk",
Created: state.Created,
Model: state.Model,
Choices: []ChatChunkChoice{{
Index: 0,
Delta: ChatDelta{Content: &empty},
FinishReason: &finishReason,
}},
}
}
// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID.
func generateChatCmplID() string {
b := make([]byte, 12)
_, _ = rand.Read(b)
return "chatcmpl-" + hex.EncodeToString(b)
}

View File

@@ -329,6 +329,148 @@ type ResponsesStreamEvent struct {
SequenceNumber int `json:"sequence_number,omitempty"`
}
// ---------------------------------------------------------------------------
// OpenAI Chat Completions API types
// ---------------------------------------------------------------------------
// ChatCompletionsRequest is the request body for POST /v1/chat/completions.
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
// Legacy function calling (deprecated but still supported)
Functions []ChatFunction `json:"functions,omitempty"`
FunctionCall json.RawMessage `json:"function_call,omitempty"`
}
// ChatStreamOptions configures streaming behavior.
type ChatStreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
// ChatMessage is a single message in the Chat Completions conversation.
type ChatMessage struct {
Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function"
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// Legacy function calling
FunctionCall *ChatFunctionCall `json:"function_call,omitempty"`
}
// ChatContentPart is a typed content part in a multi-modal message.
type ChatContentPart struct {
Type string `json:"type"` // "text" | "image_url"
Text string `json:"text,omitempty"`
ImageURL *ChatImageURL `json:"image_url,omitempty"`
}
// ChatImageURL contains the URL for an image content part.
type ChatImageURL struct {
URL string `json:"url"`
Detail string `json:"detail,omitempty"` // "auto" | "low" | "high"
}
// ChatTool describes a tool available to the model.
type ChatTool struct {
Type string `json:"type"` // "function"
Function *ChatFunction `json:"function,omitempty"`
}
// ChatFunction describes a function tool definition.
type ChatFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Strict *bool `json:"strict,omitempty"`
}
// ChatToolCall represents a tool call made by the assistant.
// Index is only populated in streaming chunks (omitted in non-streaming responses).
type ChatToolCall struct {
Index *int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"` // "function"
Function ChatFunctionCall `json:"function"`
}
// ChatFunctionCall contains the function name and arguments.
type ChatFunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions.
type ChatCompletionsResponse struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChoice is a single completion choice.
type ChatChoice struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter"
}
// ChatUsage holds token counts in Chat Completions format.
type ChatUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"`
}
// ChatTokenDetails provides a breakdown of token usage.
type ChatTokenDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
}
// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions.
type ChatCompletionsChunk struct {
ID string `json:"id"`
Object string `json:"object"` // "chat.completion.chunk"
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChunkChoice `json:"choices"`
Usage *ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
}
// ChatChunkChoice is a single choice in a streaming chunk.
type ChatChunkChoice struct {
Index int `json:"index"`
Delta ChatDelta `json:"delta"`
FinishReason *string `json:"finish_reason"` // pointer: null when not final
}
// ChatDelta carries incremental content in a streaming chunk.
type ChatDelta struct {
Role string `json:"role,omitempty"`
Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters
ToolCalls []ChatToolCall `json:"tool_calls,omitempty"`
}
// ---------------------------------------------------------------------------
// Shared constants
// ---------------------------------------------------------------------------

View File

@@ -1,513 +0,0 @@
package service
import (
"encoding/json"
"errors"
"strings"
"time"
)
// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request.
func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) {
if req == nil {
return nil, errors.New("request is nil")
}
model := strings.TrimSpace(getString(req["model"]))
if model == "" {
return nil, errors.New("model is required")
}
messagesRaw, ok := req["messages"]
if !ok {
return nil, errors.New("messages is required")
}
messages, ok := messagesRaw.([]any)
if !ok {
return nil, errors.New("messages must be an array")
}
input, err := convertChatMessagesToResponsesInput(messages)
if err != nil {
return nil, err
}
out := make(map[string]any, len(req)+1)
for key, value := range req {
switch key {
case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call":
continue
default:
out[key] = value
}
}
out["model"] = model
out["input"] = input
if _, ok := out["max_output_tokens"]; !ok {
if v, ok := req["max_tokens"]; ok {
out["max_output_tokens"] = v
} else if v, ok := req["max_completion_tokens"]; ok {
out["max_output_tokens"] = v
}
}
if _, ok := out["tools"]; !ok {
if functions, ok := req["functions"].([]any); ok && len(functions) > 0 {
tools := make([]any, 0, len(functions))
for _, fn := range functions {
if fnMap, ok := fn.(map[string]any); ok {
tools = append(tools, map[string]any{
"type": "function",
"function": fnMap,
})
}
}
if len(tools) > 0 {
out["tools"] = tools
}
}
}
if _, ok := out["tool_choice"]; !ok {
if functionCall, ok := req["function_call"]; ok {
out["tool_choice"] = functionCall
}
}
return out, nil
}
// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format.
func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return nil, err
}
id := strings.TrimSpace(getString(resp["id"]))
if id == "" {
id = "chatcmpl-" + safeRandomHex(12)
}
model := strings.TrimSpace(getString(resp["model"]))
created := getInt64(resp["created_at"])
if created == 0 {
created = getInt64(resp["created"])
}
if created == 0 {
created = time.Now().Unix()
}
text, toolCalls := extractResponseTextAndToolCalls(resp)
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
message := map[string]any{
"role": "assistant",
"content": text,
}
if len(toolCalls) > 0 {
message["tool_calls"] = toolCalls
}
chatResp := map[string]any{
"id": id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": []any{
map[string]any{
"index": 0,
"message": message,
"finish_reason": finishReason,
},
},
}
if usage := extractResponseUsage(resp); usage != nil {
chatResp["usage"] = usage
}
if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" {
chatResp["system_fingerprint"] = fingerprint
}
return json.Marshal(chatResp)
}
func convertChatMessagesToResponsesInput(messages []any) ([]any, error) {
input := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
return nil, errors.New("message must be an object")
}
role := strings.TrimSpace(getString(msgMap["role"]))
if role == "" {
return nil, errors.New("message role is required")
}
switch role {
case "tool":
callID := strings.TrimSpace(getString(msgMap["tool_call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(msgMap["id"]))
}
output := extractMessageContentText(msgMap["content"])
input = append(input, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
case "function":
callID := strings.TrimSpace(getString(msgMap["name"]))
output := extractMessageContentText(msgMap["content"])
input = append(input, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
default:
convertedContent := convertChatContent(msgMap["content"])
toolCalls := []any(nil)
if role == "assistant" {
toolCalls = extractToolCallsFromMessage(msgMap)
}
skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent)
if !skipAssistantMessage {
msgItem := map[string]any{
"role": role,
"content": convertedContent,
}
if name := strings.TrimSpace(getString(msgMap["name"])); name != "" {
msgItem["name"] = name
}
input = append(input, msgItem)
}
if role == "assistant" && len(toolCalls) > 0 {
input = append(input, toolCalls...)
}
}
}
return input, nil
}
func convertChatContent(content any) any {
switch v := content.(type) {
case nil:
return ""
case string:
return v
case []any:
converted := make([]any, 0, len(v))
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
converted = append(converted, part)
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
switch partType {
case "text":
text := getString(partMap["text"])
if text != "" {
converted = append(converted, map[string]any{
"type": "input_text",
"text": text,
})
continue
}
case "image_url":
imageURL := ""
if imageObj, ok := partMap["image_url"].(map[string]any); ok {
imageURL = getString(imageObj["url"])
} else {
imageURL = getString(partMap["image_url"])
}
if imageURL != "" {
converted = append(converted, map[string]any{
"type": "input_image",
"image_url": imageURL,
})
continue
}
case "input_text", "input_image":
converted = append(converted, partMap)
continue
}
converted = append(converted, partMap)
}
return converted
default:
return v
}
}
func extractToolCallsFromMessage(msg map[string]any) []any {
var out []any
if toolCalls, ok := msg["tool_calls"].([]any); ok {
for _, call := range toolCalls {
callMap, ok := call.(map[string]any)
if !ok {
continue
}
callID := strings.TrimSpace(getString(callMap["id"]))
if callID == "" {
callID = strings.TrimSpace(getString(callMap["call_id"]))
}
name := ""
args := ""
if fn, ok := callMap["function"].(map[string]any); ok {
name = strings.TrimSpace(getString(fn["name"]))
args = getString(fn["arguments"])
}
if name == "" && args == "" {
continue
}
item := map[string]any{
"type": "tool_call",
}
if callID != "" {
item["call_id"] = callID
}
if name != "" {
item["name"] = name
}
if args != "" {
item["arguments"] = args
}
out = append(out, item)
}
}
if fnCall, ok := msg["function_call"].(map[string]any); ok {
name := strings.TrimSpace(getString(fnCall["name"]))
args := getString(fnCall["arguments"])
if name != "" || args != "" {
callID := strings.TrimSpace(getString(msg["tool_call_id"]))
if callID == "" {
callID = name
}
item := map[string]any{
"type": "function_call",
}
if callID != "" {
item["call_id"] = callID
}
if name != "" {
item["name"] = name
}
if args != "" {
item["arguments"] = args
}
out = append(out, item)
}
}
return out
}
func extractMessageContentText(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
parts := make([]string, 0, len(v))
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
if partType == "" || partType == "text" || partType == "output_text" || partType == "input_text" {
text := getString(partMap["text"])
if text != "" {
parts = append(parts, text)
}
}
}
return strings.Join(parts, "")
default:
return ""
}
}
func isEmptyContent(content any) bool {
switch v := content.(type) {
case nil:
return true
case string:
return strings.TrimSpace(v) == ""
case []any:
return len(v) == 0
default:
return false
}
}
func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) {
output, ok := resp["output"].([]any)
if !ok {
if text, ok := resp["output_text"].(string); ok {
return text, nil
}
return "", nil
}
textParts := make([]string, 0)
toolCalls := make([]any, 0)
for _, item := range output {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
itemType := strings.TrimSpace(getString(itemMap["type"]))
if itemType == "tool_call" || itemType == "function_call" {
if tc := responseItemToChatToolCall(itemMap); tc != nil {
toolCalls = append(toolCalls, tc)
}
continue
}
content := itemMap["content"]
switch v := content.(type) {
case string:
if v != "" {
textParts = append(textParts, v)
}
case []any:
for _, part := range v {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
partType := strings.TrimSpace(getString(partMap["type"]))
switch partType {
case "output_text", "text", "input_text":
text := getString(partMap["text"])
if text != "" {
textParts = append(textParts, text)
}
case "tool_call", "function_call":
if tc := responseItemToChatToolCall(partMap); tc != nil {
toolCalls = append(toolCalls, tc)
}
}
}
}
}
return strings.Join(textParts, ""), toolCalls
}
func responseItemToChatToolCall(item map[string]any) map[string]any {
callID := strings.TrimSpace(getString(item["call_id"]))
if callID == "" {
callID = strings.TrimSpace(getString(item["id"]))
}
name := strings.TrimSpace(getString(item["name"]))
arguments := getString(item["arguments"])
if fn, ok := item["function"].(map[string]any); ok {
if name == "" {
name = strings.TrimSpace(getString(fn["name"]))
}
if arguments == "" {
arguments = getString(fn["arguments"])
}
}
if name == "" && arguments == "" && callID == "" {
return nil
}
if callID == "" {
callID = "call_" + safeRandomHex(6)
}
return map[string]any{
"id": callID,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": arguments,
},
}
}
func extractResponseUsage(resp map[string]any) map[string]any {
usage, ok := resp["usage"].(map[string]any)
if !ok {
return nil
}
promptTokens := int(getNumber(usage["input_tokens"]))
completionTokens := int(getNumber(usage["output_tokens"]))
if promptTokens == 0 && completionTokens == 0 {
return nil
}
return map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
}
}
func getString(value any) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case json.Number:
return v.String()
default:
return ""
}
}
func getNumber(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
f, _ := v.Float64()
return f
default:
return 0
}
}
func getInt64(value any) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case float64:
return int64(v)
case json.Number:
i, _ := v.Int64()
return i
default:
return 0
}
}
func safeRandomHex(byteLength int) string {
value, err := randomHexString(byteLength)
if err != nil || value == "" {
return "000000"
}
return value
}

View File

@@ -1,488 +0,0 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
type chatStreamingResult struct {
usage *OpenAIUsage
firstTokenMs *int
}
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
originalModel := reqModel
bodyModified := false
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
reqBody["model"] = mappedModel
bodyModified = true
}
if reqStream && includeUsage {
streamOptions, _ := reqBody["stream_options"].(map[string]any)
if streamOptions == nil {
streamOptions = map[string]any{}
}
if _, ok := streamOptions["include_usage"]; !ok {
streamOptions["include_usage"] = true
reqBody["stream_options"] = streamOptions
bodyModified = true
}
}
if bodyModified {
var err error
body, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("serialize request body: %w", err)
}
}
// Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
if err != nil {
return nil, err
}
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
}
var usage *OpenAIUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
if err != nil {
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
if err != nil {
return nil, err
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
return &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
var targetURL string
baseURL := account.GetOpenAIBaseURL()
if baseURL == "" {
targetURL = openaiChatAPIURL
} else {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/chat/completions"
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("authorization", "Bearer "+token)
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if openaiChatAllowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
customUA := account.GetOpenAIUserAgent()
if customUA != "" {
req.Header.Set("user-agent", customUA)
}
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
return req, nil
}
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for {
select {
case ev, ok := <-events:
if !ok {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
lastDataAt = time.Now()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
line = "data: " + correctedData
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
if firstTokenMs == nil {
if event := parseChatStreamEvent(data); event != nil {
if chatChunkHasDelta(event) {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
applyChatUsageFromEvent(event, usage)
}
} else {
if event := parseChatStreamEvent(data); event != nil {
applyChatUsageFromEvent(event, usage)
}
}
} else {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
}
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
usage := &OpenAIUsage{}
var parsed map[string]any
if json.Unmarshal(body, &parsed) == nil {
if usageMap, ok := parsed["usage"].(map[string]any); ok {
applyChatUsageFromMap(usageMap, usage)
}
}
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
body = s.correctToolCallsInResponseBody(body)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func parseChatStreamEvent(data string) map[string]any {
if data == "" || data == "[DONE]" {
return nil
}
var event map[string]any
if json.Unmarshal([]byte(data), &event) != nil {
return nil
}
return event
}
func chatChunkHasDelta(event map[string]any) bool {
choices, ok := event["choices"].([]any)
if !ok {
return false
}
for _, choice := range choices {
choiceMap, ok := choice.(map[string]any)
if !ok {
continue
}
delta, ok := choiceMap["delta"].(map[string]any)
if !ok {
continue
}
if content, ok := delta["content"].(string); ok && strings.TrimSpace(content) != "" {
return true
}
if toolCalls, ok := delta["tool_calls"].([]any); ok && len(toolCalls) > 0 {
return true
}
if functionCall, ok := delta["function_call"].(map[string]any); ok && len(functionCall) > 0 {
return true
}
}
return false
}
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
if event == nil || usage == nil {
return
}
usageMap, ok := event["usage"].(map[string]any)
if !ok {
return
}
applyChatUsageFromMap(usageMap, usage)
}
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
if usageMap == nil || usage == nil {
return
}
promptTokens := int(getNumber(usageMap["prompt_tokens"]))
completionTokens := int(getNumber(usageMap["completion_tokens"]))
if promptTokens > 0 {
usage.InputTokens = promptTokens
}
if completionTokens > 0 {
usage.OutputTokens = completionTokens
}
}

View File

@@ -1,132 +0,0 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestConvertChatCompletionsToResponses(t *testing.T) {
req := map[string]any{
"model": "gpt-4o",
"messages": []any{
map[string]any{
"role": "user",
"content": "hello",
},
map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "ping",
"arguments": "{}",
},
},
},
},
map[string]any{
"role": "tool",
"tool_call_id": "call_1",
"content": "ok",
"response": "ignored",
"response_time": 1,
},
},
"functions": []any{
map[string]any{
"name": "ping",
"description": "ping tool",
"parameters": map[string]any{"type": "object"},
},
},
"function_call": map[string]any{"name": "ping"},
}
converted, err := ConvertChatCompletionsToResponses(req)
require.NoError(t, err)
require.Equal(t, "gpt-4o", converted["model"])
input, ok := converted["input"].([]any)
require.True(t, ok)
require.Len(t, input, 3)
toolCall := findInputItemByType(input, "tool_call")
require.NotNil(t, toolCall)
require.Equal(t, "call_1", toolCall["call_id"])
toolOutput := findInputItemByType(input, "function_call_output")
require.NotNil(t, toolOutput)
require.Equal(t, "call_1", toolOutput["call_id"])
tools, ok := converted["tools"].([]any)
require.True(t, ok)
require.Len(t, tools, 1)
require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"])
}
func TestConvertResponsesToChatCompletion(t *testing.T) {
resp := map[string]any{
"id": "resp_123",
"model": "gpt-4o",
"created_at": 1700000000,
"output": []any{
map[string]any{
"type": "message",
"role": "assistant",
"content": []any{
map[string]any{
"type": "output_text",
"text": "hi",
},
},
},
},
"usage": map[string]any{
"input_tokens": 2,
"output_tokens": 3,
},
}
body, err := json.Marshal(resp)
require.NoError(t, err)
converted, err := ConvertResponsesToChatCompletion(body)
require.NoError(t, err)
var chat map[string]any
require.NoError(t, json.Unmarshal(converted, &chat))
require.Equal(t, "chat.completion", chat["object"])
choices, ok := chat["choices"].([]any)
require.True(t, ok)
require.Len(t, choices, 1)
choice, ok := choices[0].(map[string]any)
require.True(t, ok)
message, ok := choice["message"].(map[string]any)
require.True(t, ok)
require.Equal(t, "hi", message["content"])
usage, ok := chat["usage"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(2), usage["prompt_tokens"])
require.Equal(t, float64(3), usage["completion_tokens"])
require.Equal(t, float64(5), usage["total_tokens"])
}
func findInputItemByType(items []any, itemType string) map[string]any {
for _, item := range items {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
if itemMap["type"] == itemType {
return itemMap
}
}
return nil
}

View File

@@ -0,0 +1,512 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ForwardAsChatCompletions accepts a Chat Completions request body, converts it
// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts
// the response back to Chat Completions format. All account types (OAuth and API
// Key) go through the Responses API conversion path since the upstream only
// exposes the /v1/responses endpoint.
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
c *gin.Context,
account *Account,
body []byte,
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
startTime := time.Now()
// 1. Parse Chat Completions request
var chatReq apicompat.ChatCompletionsRequest
if err := json.Unmarshal(body, &chatReq); err != nil {
return nil, fmt.Errorf("parse chat completions request: %w", err)
}
originalModel := chatReq.Model
clientStream := chatReq.Stream
includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage
// 2. Convert to Responses and forward
// ChatCompletionsToResponses always sets Stream=true (upstream always streams).
responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq)
if err != nil {
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
// 3. Model mapping
mappedModel := account.GetMappedModel(originalModel)
if mappedModel == originalModel && defaultMappedModel != "" {
mappedModel = defaultMappedModel
}
responsesReq.Model = mappedModel
logger.L().Debug("openai chat_completions: model mapping applied",
zap.Int64("account_id", account.ID),
zap.String("original_model", originalModel),
zap.String("mapped_model", mappedModel),
zap.Bool("stream", clientStream),
)
// 4. Marshal Responses request body, then apply OAuth codex transform
responsesBody, err := json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
}
if account.Type == AccountTypeOAuth {
var reqBody map[string]any
if err := json.Unmarshal(responsesBody, &reqBody); err != nil {
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
} else if promptCacheKey != "" {
reqBody["prompt_cache_key"] = promptCacheKey
}
responsesBody, err = json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("remarshal after codex transform: %w", err)
}
}
// 5. Get access token
token, _, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("get access token: %w", err)
}
// 6. Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false)
if err != nil {
return nil, fmt.Errorf("build upstream request: %w", err)
}
if promptCacheKey != "" {
upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey))
}
// 7. Send request
proxyURL := ""
if account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed")
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
defer func() { _ = resp.Body.Close() }()
// 8. Handle error response with failover
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
return s.handleChatCompletionsErrorResponse(resp, c, account)
}
// 9. Handle normal response
var result *OpenAIForwardResult
var handleErr error
if clientStream {
result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime)
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime)
}
// Propagate ServiceTier and ReasoningEffort to result for billing
if handleErr == nil && result != nil {
if responsesReq.ServiceTier != "" {
st := responsesReq.ServiceTier
result.ServiceTier = &st
}
if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" {
re := responsesReq.Reasoning.Effort
result.ReasoningEffort = &re
}
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if handleErr == nil && account.Type == AccountTypeOAuth {
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
}
return result, handleErr
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError)
}
// handleChatBufferedStreamingResponse reads all Responses SSE events from the
// upstream, finds the terminal event, converts to a Chat Completions JSON
// response, and writes it to the client.
func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
var finalResponse *apicompat.ResponsesResponse
var usage OpenAIUsage
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
payload := line[6:]
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions buffered: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil {
finalResponse = event.Response
if event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
}
}
if err := scanner.Err(); err != nil {
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions buffered: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
if finalResponse == nil {
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event")
return nil, fmt.Errorf("upstream stream ended without terminal event")
}
chatResp := apicompat.ResponsesToChatCompletions(finalResponse, originalModel)
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.JSON(http.StatusOK, chatResp)
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: false,
Duration: time.Since(startTime),
}, nil
}
// handleChatStreamingResponse reads Responses SSE events from upstream,
// converts each to Chat Completions SSE chunks, and writes them to the client.
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
originalModel string,
mappedModel string,
includeUsage bool,
startTime time.Time,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
if s.responseHeaderFilter != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
state.IncludeUsage = includeUsage
var usage OpenAIUsage
var firstTokenMs *int
firstChunk := true
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize)
resultWithUsage := func() *OpenAIForwardResult {
return &OpenAIForwardResult{
RequestID: requestID,
Usage: usage,
Model: originalModel,
BillingModel: mappedModel,
Stream: true,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}
}
processDataLine := func(payload string) bool {
if firstChunk {
firstChunk = false
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
logger.L().Warn("openai chat_completions stream: failed to parse event",
zap.Error(err),
zap.String("request_id", requestID),
)
return false
}
// Extract usage from completion events
if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") &&
event.Response != nil && event.Response.Usage != nil {
usage = OpenAIUsage{
InputTokens: event.Response.Usage.InputTokens,
OutputTokens: event.Response.Usage.OutputTokens,
}
if event.Response.Usage.InputTokensDetails != nil {
usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens
}
}
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
for _, chunk := range chunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
zap.Error(err),
zap.String("request_id", requestID),
)
continue
}
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected",
zap.String("request_id", requestID),
)
return true
}
}
if len(chunks) > 0 {
c.Writer.Flush()
}
return false
}
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 {
for _, chunk := range finalChunks {
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
fmt.Fprint(c.Writer, sse) //nolint:errcheck
}
}
// Send [DONE] sentinel
fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck
c.Writer.Flush()
return resultWithUsage(), nil
}
handleScanErr := func(err error) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
logger.L().Warn("openai chat_completions stream: read error",
zap.Error(err),
zap.String("request_id", requestID),
)
}
}
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
// No keepalive: fast synchronous path
if keepaliveInterval <= 0 {
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
}
handleScanErr(scanner.Err())
return finalizeStream()
}
// With keepalive: goroutine + channel + select
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
go func() {
defer close(events)
for scanner.Scan() {
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
keepaliveTicker := time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
lastDataAt := time.Now()
for {
select {
case ev, ok := <-events:
if !ok {
return finalizeStream()
}
if ev.err != nil {
handleScanErr(ev.err)
return finalizeStream()
}
lastDataAt = time.Now()
line := ev.line
if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" {
continue
}
if processDataLine(line[6:]) {
return resultWithUsage(), nil
}
case <-keepaliveTicker.C:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
)
return resultWithUsage(), nil
}
c.Writer.Flush()
}
}
}
// writeChatCompletionsError writes an error response in OpenAI Chat Completions format.
func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody),
RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)),
}
}
// Non-failover error: return Anthropic-formatted error to client
@@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse(
c *gin.Context,
account *Account,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Record upstream error details for ops logging
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go)
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeAnthropicError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError)
}
// handleAnthropicBufferedStreamingResponse reads all Responses SSE events from

View File

@@ -12,7 +12,6 @@ import (
"io"
"math/rand"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
@@ -37,7 +36,6 @@ const (
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiChatAPIURL = "https://api.openai.com/v1/chat/completions"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.104.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
@@ -56,16 +54,6 @@ const (
codexCLIVersion = "0.104.0"
)
// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context.
const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body"
// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context.
const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage"
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
@@ -109,19 +97,6 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
"X-Real-IP",
}
// OpenAI chat-completions allowed headers (extend responses whitelist).
var openaiChatAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"conversation_id": true,
"user-agent": true,
"originator": true,
"session_id": true,
"openai-organization": true,
"openai-project": true,
"openai-beta": true,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
@@ -1602,23 +1577,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
}
if c != nil && account != nil && account.Type == AccountTypeAPIKey {
if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok {
if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 {
includeUsage := false
if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok {
if flag, ok := v.(bool); ok {
includeUsage = flag
}
}
if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok {
passthroughWriter.SetPassthrough()
}
return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime)
}
}
}
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
@@ -2989,6 +2947,120 @@ func (s *OpenAIGatewayService) handleErrorResponse(
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// compatErrorWriter is the signature for format-specific error writers used by
// the compat paths (Chat Completions and Anthropic Messages).
type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string)
// handleCompatErrorResponse is the shared non-failover error handler for the
// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of
// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit
// tracking, secondary failover) but delegates the final error write to the
// format-specific writer function.
func (s *OpenAIGatewayService) handleCompatErrorResponse(
resp *http.Response,
c *gin.Context,
account *Account,
writeError compatErrorWriter,
) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
if upstreamMsg == "" {
upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode)
}
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// Apply error passthrough rules
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c, account.Platform, resp.StatusCode, body,
http.StatusBadGateway, "api_error", "Upstream request failed",
); matched {
writeError(c, status, errType, errMsg)
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes — if the account does not handle this status,
// return a generic error without exposing upstream details.
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error")
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg)
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
kind := "http_error"
if shouldDisable {
kind = "failover"
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
Message: upstreamMsg,
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: body,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
// Map status code to error type and write response
errType := "api_error"
switch {
case resp.StatusCode == 400:
errType = "invalid_request_error"
case resp.StatusCode == 404:
errType = "not_found_error"
case resp.StatusCode == 429:
errType = "rate_limit_error"
case resp.StatusCode >= 500:
errType = "api_error"
}
writeError(c, resp.StatusCode, errType, upstreamMsg)
return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg)
}
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
usage *OpenAIUsage