feat: add OpenAI chat completions compatibility

This commit is contained in:
yulate
2026-02-26 11:18:02 +08:00
parent c75c6b6858
commit 0bb6a39260
6 changed files with 1707 additions and 11 deletions

View File

@@ -0,0 +1,530 @@
package handler
import (
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ChatCompletions handles OpenAI Chat Completions API compatibility.
// POST /v1/chat/completions
//
// The request is translated into a Responses-API payload, delegated to the
// Responses handler, and the upstream reply is converted back to Chat
// Completions format by a wrapping response writer.
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		// Distinguish "body too large" (http.MaxBytesReader) from other read errors.
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}
	// Preserve original chat-completions request for upstream passthrough when needed.
	c.Set(service.OpenAIChatCompletionsBodyKey, body)
	var chatReq map[string]any
	if err := json.Unmarshal(body, &chatReq); err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}
	// stream_options.include_usage controls whether the final SSE chunk carries usage.
	includeUsage := false
	if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
		if v, ok := streamOptions["include_usage"].(bool); ok {
			includeUsage = v
		}
	}
	c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)
	// Translate the Chat Completions request into a Responses-API request.
	converted, err := service.ConvertChatCompletionsToResponses(chatReq)
	if err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
		return
	}
	convertedBody, err := json.Marshal(converted)
	if err != nil {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
		return
	}
	stream, _ := converted["stream"].(bool)
	model, _ := converted["model"].(string)
	// Swap in a writer that converts the Responses output back to Chat
	// Completions format, then replace the request body and delegate.
	writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
	c.Writer = writer
	c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
	c.Request.ContentLength = int64(len(convertedBody))
	h.Responses(c)
	// Flush any buffered non-streaming response after the handler returns.
	writer.Finalize()
}
// chatCompletionsResponseWriter wraps gin's ResponseWriter to convert
// Responses-API output (JSON body or SSE stream) into Chat Completions
// format before it reaches the client.
type chatCompletionsResponseWriter struct {
	gin.ResponseWriter
	stream       bool         // client requested a streaming (SSE) response
	includeUsage bool         // client asked for usage in the final stream chunk
	buffer       bytes.Buffer // accumulates the full body for non-streaming responses
	streamBuf    bytes.Buffer // holds partial SSE lines until a newline arrives
	state        *chatCompletionStreamState
	corrector    *service.CodexToolCorrector
	finalized    bool // Finalize has run; later non-stream writes are dropped
	passthrough  bool // forward bytes unmodified, no conversion
}
// chatCompletionStreamState tracks per-stream metadata needed to synthesize
// chat.completion.chunk payloads from Responses-API SSE events.
type chatCompletionStreamState struct {
	id            string // completion id, from upstream or generated lazily
	model         string
	created       int64          // unix seconds
	sentRole      bool           // whether the "assistant" role delta was already emitted
	sawToolCall   bool           // any tool call seen (drives finish_reason)
	sawText       bool           // any text delta seen (suppresses duplicate done-text)
	toolCallIndex map[string]int // call_id -> stable tool_calls array index
	usage         map[string]any // usage payload for the final chunk, if reported
}
// newChatCompletionsResponseWriter wraps w so Responses-API output can be
// rewritten into Chat Completions format before reaching the client.
func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter {
	streamState := &chatCompletionStreamState{
		model:         strings.TrimSpace(model),
		toolCallIndex: make(map[string]int),
	}
	return &chatCompletionsResponseWriter{
		ResponseWriter: w,
		stream:         stream,
		includeUsage:   includeUsage,
		state:          streamState,
		corrector:      service.NewCodexToolCorrector(),
	}
}
// Write intercepts upstream bytes: passthrough mode forwards them unchanged,
// streaming input is line-buffered for SSE conversion, and non-streaming
// input is accumulated until Finalize.
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
	switch {
	case w.passthrough:
		return w.ResponseWriter.Write(data)
	case w.stream:
		n, err := w.streamBuf.Write(data)
		if err == nil {
			w.flushStreamBuffer()
		}
		return n, err
	case w.finalized:
		// Late writes after finalization are accepted and silently dropped.
		return len(data), nil
	default:
		return w.buffer.Write(data)
	}
}
// WriteString implements io.StringWriter by delegating to Write so string
// payloads receive the same interception handling as byte payloads.
func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) {
	return w.Write([]byte(s))
}
// Finalize flushes the buffered non-streaming response, converting it from
// Responses format to a Chat Completions body. Streaming and passthrough
// responses were already written incrementally, so Finalize is a no-op for
// them. Idempotent: repeated calls return immediately.
func (w *chatCompletionsResponseWriter) Finalize() {
	if w.finalized {
		return
	}
	w.finalized = true
	if w.passthrough {
		return
	}
	if w.stream {
		return
	}
	body := w.buffer.Bytes()
	if len(body) == 0 {
		return
	}
	// The converted body differs in length from what upstream declared.
	w.ResponseWriter.Header().Del("Content-Length")
	converted, err := service.ConvertResponsesToChatCompletion(body)
	if err != nil {
		// Conversion failed (e.g. non-JSON or error payload): emit the
		// original body unmodified rather than dropping the response.
		_, _ = w.ResponseWriter.Write(body)
		return
	}
	corrected := converted
	if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
		corrected = []byte(correctedStr)
	}
	_, _ = w.ResponseWriter.Write(corrected)
}
// SetPassthrough disables all conversion; subsequent writes go straight to
// the underlying writer unmodified.
func (w *chatCompletionsResponseWriter) SetPassthrough() {
	w.passthrough = true
}
// flushStreamBuffer drains every complete line ("\n"-terminated) from the
// stream buffer and hands it to handleStreamLine; a trailing partial line
// stays buffered until more bytes arrive.
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
	for {
		pending := w.streamBuf.Bytes()
		newline := bytes.IndexByte(pending, '\n')
		if newline < 0 {
			return
		}
		raw := w.streamBuf.Next(newline + 1)
		w.handleStreamLine(strings.TrimRight(string(raw), "\r\n"))
	}
}
// handleStreamLine processes one upstream SSE line: comment lines (":…") are
// forwarded as keep-alives, "data:" payloads are converted into chat chunks,
// and everything else (e.g. "event:" lines) is dropped.
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
	if line == "" {
		return
	}
	if strings.HasPrefix(line, ":") {
		// SSE comment / keep-alive: pass through with a blank-line terminator.
		_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
		return
	}
	if !strings.HasPrefix(line, "data:") {
		return
	}
	data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	// One upstream event may map to zero, one, or several chat chunks.
	for _, chunk := range w.convertResponseDataToChatChunks(data) {
		if chunk == "" {
			continue
		}
		if chunk == "[DONE]" {
			_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			continue
		}
		_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
	}
}
// convertResponseDataToChatChunks translates one Responses-API SSE data
// payload into zero or more chat.completion.chunk JSON strings.
// Unparseable payloads, error objects, and untyped events are forwarded
// verbatim so the client still sees them; unknown event types yield nothing.
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
	if data == "" {
		return nil
	}
	if data == "[DONE]" {
		return []string{"[DONE]"}
	}
	var payload map[string]any
	if err := json.Unmarshal([]byte(data), &payload); err != nil {
		// Not JSON: forward verbatim rather than dropping it.
		return []string{data}
	}
	if _, ok := payload["error"]; ok {
		// Upstream error objects pass through unmodified.
		return []string{data}
	}
	eventType := strings.TrimSpace(getString(payload["type"]))
	if eventType == "" {
		return []string{data}
	}
	// Harvest id/model/created from every event before dispatching on type.
	w.state.applyMetadata(payload)
	switch eventType {
	case "response.created":
		return nil
	case "response.output_text.delta":
		delta := getString(payload["delta"])
		if delta == "" {
			return nil
		}
		w.state.sawText = true
		return []string{w.buildTextDeltaChunk(delta)}
	case "response.output_text.done":
		// Only emit the full text if no incremental deltas were seen;
		// otherwise the client would receive the text twice.
		if w.state.sawText {
			return nil
		}
		text := getString(payload["text"])
		if text == "" {
			return nil
		}
		w.state.sawText = true
		return []string{w.buildTextDeltaChunk(text)}
	case "response.output_item.added", "response.output_item.delta":
		if item, ok := payload["item"].(map[string]any); ok {
			if callID, name, args, ok := extractToolCallFromItem(item); ok {
				w.state.sawToolCall = true
				return []string{w.buildToolCallChunk(callID, name, args)}
			}
		}
	case "response.completed", "response.done":
		if responseObj, ok := payload["response"].(map[string]any); ok {
			w.state.applyResponseUsage(responseObj)
		}
		return []string{w.buildFinalChunk()}
	}
	// Fallback for other tool/function-call event shapes: try the common
	// field spellings for the call id before giving up.
	if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
		callID := strings.TrimSpace(getString(payload["call_id"]))
		if callID == "" {
			callID = strings.TrimSpace(getString(payload["tool_call_id"]))
		}
		if callID == "" {
			callID = strings.TrimSpace(getString(payload["id"]))
		}
		args := getString(payload["delta"])
		name := strings.TrimSpace(getString(payload["name"]))
		if callID != "" && (args != "" || name != "") {
			w.state.sawToolCall = true
			return []string{w.buildToolCallChunk(callID, name, args)}
		}
	}
	return nil
}
// buildTextDeltaChunk wraps a text fragment in a chat.completion.chunk,
// attaching the assistant role on the first delta of the stream.
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
	w.state.ensureDefaults()
	d := map[string]any{"content": delta}
	if !w.state.sentRole {
		w.state.sentRole = true
		d["role"] = "assistant"
	}
	return w.buildChunk(d, nil, nil)
}
// buildToolCallChunk emits a chat.completion.chunk carrying one tool-call
// delta. The call keeps a stable array index per call ID so clients can
// accumulate arguments across chunks.
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
	w.state.ensureDefaults()
	fn := map[string]any{}
	if name != "" {
		fn["name"] = name
	}
	if args != "" {
		fn["arguments"] = args
	}
	call := map[string]any{
		"index":    w.state.toolCallIndexFor(callID),
		"id":       callID,
		"type":     "function",
		"function": fn,
	}
	payload := map[string]any{
		"tool_calls": []any{call},
	}
	if !w.state.sentRole {
		w.state.sentRole = true
		payload["role"] = "assistant"
	}
	return w.buildChunk(payload, nil, nil)
}
// buildFinalChunk produces the terminating chunk with a finish_reason and,
// when the client requested it, the accumulated usage.
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
	w.state.ensureDefaults()
	reason := "stop"
	if w.state.sawToolCall {
		reason = "tool_calls"
	}
	var usage map[string]any
	if w.includeUsage {
		usage = w.state.usage
	}
	return w.buildChunk(map[string]any{}, reason, usage)
}
// buildChunk assembles one chat.completion.chunk JSON string with the given
// delta, finish_reason (nil while streaming content), and optional usage
// map, then runs the Codex tool-call corrector over the serialized result.
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
	w.state.ensureDefaults()
	chunk := map[string]any{
		"id":      w.state.id,
		"object":  "chat.completion.chunk",
		"created": w.state.created,
		"model":   w.state.model,
		"choices": []any{
			map[string]any{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			},
		},
	}
	if usage != nil {
		chunk["usage"] = usage
	}
	// Marshal of map[string]any cannot fail here; error intentionally ignored.
	data, _ := json.Marshal(chunk)
	if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok {
		return corrected
	}
	return string(data)
}
// ensureDefaults backfills id/model/created so a chunk can always be built,
// even when upstream never supplied response metadata.
func (s *chatCompletionStreamState) ensureDefaults() {
	if s.created == 0 {
		s.created = time.Now().Unix()
	}
	if s.model == "" {
		s.model = "unknown"
	}
	if s.id == "" {
		s.id = "chatcmpl-" + randomHexUnsafe(12)
	}
}
// toolCallIndexFor returns a stable, zero-based index for callID, assigning
// the next free slot on first sight.
func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int {
	idx, seen := s.toolCallIndex[callID]
	if !seen {
		idx = len(s.toolCallIndex)
		s.toolCallIndex[callID] = idx
	}
	return idx
}
// applyMetadata captures stream identity fields (id, model, created) from an
// event payload, preferring the nested "response" object, and never
// overwrites values that are already set.
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
	if responseObj, ok := payload["response"].(map[string]any); ok {
		s.applyResponseMetadata(responseObj)
	}
	if s.id == "" {
		// Some events carry "response_id", others a top-level "id".
		if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
			s.id = id
		} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
			s.id = id
		}
	}
	if s.model == "" {
		if model := strings.TrimSpace(getString(payload["model"])); model != "" {
			s.model = model
		}
	}
	if s.created == 0 {
		// Prefer Responses-style "created_at", fall back to chat-style "created".
		if created := getInt64(payload["created_at"]); created != 0 {
			s.created = created
		} else if created := getInt64(payload["created"]); created != 0 {
			s.created = created
		}
	}
}
// applyResponseMetadata fills empty id/model/created fields from a response
// object, leaving already-set values untouched.
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
	if s.id == "" {
		if v := strings.TrimSpace(getString(responseObj["id"])); v != "" {
			s.id = v
		}
	}
	if s.model == "" {
		if v := strings.TrimSpace(getString(responseObj["model"])); v != "" {
			s.model = v
		}
	}
	if s.created == 0 {
		if v := getInt64(responseObj["created_at"]); v != 0 {
			s.created = v
		}
	}
}
// applyResponseUsage converts Responses-API usage (input/output tokens) into
// the Chat Completions shape. All-zero usage is ignored so any previously
// captured usage is preserved.
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
	raw, ok := responseObj["usage"].(map[string]any)
	if !ok {
		return
	}
	prompt := int(getNumber(raw["input_tokens"]))
	completion := int(getNumber(raw["output_tokens"]))
	if prompt == 0 && completion == 0 {
		return
	}
	s.usage = map[string]any{
		"prompt_tokens":     prompt,
		"completion_tokens": completion,
		"total_tokens":      prompt + completion,
	}
}
// extractToolCallFromItem pulls (callID, name, arguments) out of a Responses
// output item of type "tool_call" or "function_call". It supports both flat
// fields and the nested "function" object, and generates a call ID when the
// item has none. Returns ok=false for non-tool items or items with no usable
// fields at all.
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
	itemType := strings.TrimSpace(getString(item["type"]))
	if itemType != "tool_call" && itemType != "function_call" {
		return "", "", "", false
	}
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	args := getString(item["arguments"])
	// Nested "function" object takes effect only where flat fields were empty.
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if args == "" {
			args = getString(fn["arguments"])
		}
	}
	if callID == "" && name == "" && args == "" {
		return "", "", "", false
	}
	if callID == "" {
		callID = "call_" + randomHexUnsafe(6)
	}
	return callID, name, args, true
}
// getString coerces a decoded JSON value to a string; unsupported types
// yield "".
func getString(value any) string {
	switch v := value.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	case json.Number:
		return v.String()
	}
	return ""
}
// getNumber coerces common numeric JSON representations to float64;
// unsupported types yield 0.
func getNumber(value any) float64 {
	if n, ok := value.(json.Number); ok {
		f, _ := n.Float64()
		return f
	}
	switch v := value.(type) {
	case float64:
		return v
	case float32:
		return float64(v)
	case int:
		return float64(v)
	case int64:
		return float64(v)
	}
	return 0
}
// getInt64 coerces common numeric JSON representations to int64, truncating
// floats; unsupported types yield 0.
func getInt64(value any) int64 {
	if n, ok := value.(json.Number); ok {
		i, _ := n.Int64()
		return i
	}
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case float64:
		return int64(v)
	}
	return 0
}
// randomHexUnsafe returns byteLength random bytes hex-encoded (2*byteLength
// characters); non-positive lengths fall back to 8 bytes. On entropy failure
// it degrades to an all-zero string of the expected width — previously it
// returned a fixed 6-character "000000", which produced wrong-shaped IDs for
// any other requested length. "Unsafe" means it never returns an error, not
// that the output is predictable.
func randomHexUnsafe(byteLength int) string {
	if byteLength <= 0 {
		byteLength = 8
	}
	buf := make([]byte, byteLength)
	if _, err := rand.Read(buf); err != nil {
		// Degraded fallback keeps the caller-visible width invariant.
		return strings.Repeat("0", 2*byteLength)
	}
	return hex.EncodeToString(buf)
}

View File

@@ -1,8 +1,6 @@
package routes
import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -43,15 +41,8 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
// 明确阻止旧协议入口OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
},
})
})
// OpenAI Chat Completions API
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
@@ -69,6 +60,8 @@ func RegisterGatewayRoutes(
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// OpenAI Chat Completions API不带v1前缀的别名
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ChatCompletions)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)

View File

@@ -0,0 +1,513 @@
package service
import (
"encoding/json"
"errors"
"strings"
"time"
)
// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request.
// Chat-only fields (messages, max_tokens, max_completion_tokens,
// stream_options, functions, function_call) are dropped or re-expressed in
// Responses terms; every other field is copied through unchanged.
func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) {
	if req == nil {
		return nil, errors.New("request is nil")
	}
	model := strings.TrimSpace(getString(req["model"]))
	if model == "" {
		return nil, errors.New("model is required")
	}
	messagesRaw, ok := req["messages"]
	if !ok {
		return nil, errors.New("messages is required")
	}
	messages, ok := messagesRaw.([]any)
	if !ok {
		return nil, errors.New("messages must be an array")
	}
	input, err := convertChatMessagesToResponsesInput(messages)
	if err != nil {
		return nil, err
	}
	// Copy everything except chat-specific keys, which are handled below.
	out := make(map[string]any, len(req)+1)
	for key, value := range req {
		switch key {
		case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call":
			continue
		default:
			out[key] = value
		}
	}
	out["model"] = model
	out["input"] = input
	// max_tokens / max_completion_tokens -> max_output_tokens (an explicit
	// max_output_tokens from the caller wins).
	if _, ok := out["max_output_tokens"]; !ok {
		if v, ok := req["max_tokens"]; ok {
			out["max_output_tokens"] = v
		} else if v, ok := req["max_completion_tokens"]; ok {
			out["max_output_tokens"] = v
		}
	}
	// Legacy "functions" -> "tools", only when tools were not already given.
	// NOTE(review): entries are wrapped as {"type":"function","function":…};
	// confirm the Responses API accepts this nested tool shape rather than a
	// flattened {"type":"function","name":…,"parameters":…} form.
	if _, ok := out["tools"]; !ok {
		if functions, ok := req["functions"].([]any); ok && len(functions) > 0 {
			tools := make([]any, 0, len(functions))
			for _, fn := range functions {
				if fnMap, ok := fn.(map[string]any); ok {
					tools = append(tools, map[string]any{
						"type":     "function",
						"function": fnMap,
					})
				}
			}
			if len(tools) > 0 {
				out["tools"] = tools
			}
		}
	}
	// Legacy "function_call" -> "tool_choice".
	if _, ok := out["tool_choice"]; !ok {
		if functionCall, ok := req["function_call"]; ok {
			out["tool_choice"] = functionCall
		}
	}
	return out, nil
}
// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format.
// Missing id/created values are synthesized so the result is always a
// well-formed chat.completion object; it errors only when body is not valid
// JSON.
func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) {
	var resp map[string]any
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	id := strings.TrimSpace(getString(resp["id"]))
	if id == "" {
		id = "chatcmpl-" + safeRandomHex(12)
	}
	model := strings.TrimSpace(getString(resp["model"]))
	// Prefer Responses-style "created_at", then chat-style "created",
	// then the current time.
	created := getInt64(resp["created_at"])
	if created == 0 {
		created = getInt64(resp["created"])
	}
	if created == 0 {
		created = time.Now().Unix()
	}
	text, toolCalls := extractResponseTextAndToolCalls(resp)
	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}
	message := map[string]any{
		"role":    "assistant",
		"content": text,
	}
	if len(toolCalls) > 0 {
		message["tool_calls"] = toolCalls
	}
	chatResp := map[string]any{
		"id":      id,
		"object":  "chat.completion",
		"created": created,
		"model":   model,
		"choices": []any{
			map[string]any{
				"index":         0,
				"message":       message,
				"finish_reason": finishReason,
			},
		},
	}
	if usage := extractResponseUsage(resp); usage != nil {
		chatResp["usage"] = usage
	}
	if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" {
		chatResp["system_fingerprint"] = fingerprint
	}
	return json.Marshal(chatResp)
}
// convertChatMessagesToResponsesInput maps Chat Completions messages onto
// Responses-API input items:
//   - "tool"/"function" messages become function_call_output items;
//   - assistant tool calls are appended as separate items after the message
//     (the message itself is skipped when it carries no content);
//   - every other message passes through with its content parts converted.
func convertChatMessagesToResponsesInput(messages []any) ([]any, error) {
	input := make([]any, 0, len(messages))
	for _, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			return nil, errors.New("message must be an object")
		}
		role := strings.TrimSpace(getString(msgMap["role"]))
		if role == "" {
			return nil, errors.New("message role is required")
		}
		switch role {
		case "tool":
			callID := strings.TrimSpace(getString(msgMap["tool_call_id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(msgMap["id"]))
			}
			output := extractMessageContentText(msgMap["content"])
			input = append(input, map[string]any{
				"type":    "function_call_output",
				"call_id": callID,
				"output":  output,
			})
		case "function":
			// Legacy function messages have no call id; the function name
			// stands in for it.
			callID := strings.TrimSpace(getString(msgMap["name"]))
			output := extractMessageContentText(msgMap["content"])
			input = append(input, map[string]any{
				"type":    "function_call_output",
				"call_id": callID,
				"output":  output,
			})
		default:
			convertedContent := convertChatContent(msgMap["content"])
			toolCalls := []any(nil)
			if role == "assistant" {
				toolCalls = extractToolCallsFromMessage(msgMap)
			}
			// An assistant message that only carries tool calls contributes
			// no message item of its own.
			skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent)
			if !skipAssistantMessage {
				msgItem := map[string]any{
					"role":    role,
					"content": convertedContent,
				}
				if name := strings.TrimSpace(getString(msgMap["name"])); name != "" {
					msgItem["name"] = name
				}
				input = append(input, msgItem)
			}
			if role == "assistant" && len(toolCalls) > 0 {
				input = append(input, toolCalls...)
			}
		}
	}
	return input, nil
}
// convertChatContent rewrites Chat Completions content into Responses input
// parts: "text" -> "input_text", "image_url" -> "input_image". Strings pass
// through, nil becomes "", and unrecognized or empty parts are kept as-is
// (note the cases deliberately fall through to the trailing append when the
// extracted text/url is empty).
func convertChatContent(content any) any {
	switch v := content.(type) {
	case nil:
		return ""
	case string:
		return v
	case []any:
		converted := make([]any, 0, len(v))
		for _, part := range v {
			partMap, ok := part.(map[string]any)
			if !ok {
				// Non-object parts are preserved untouched.
				converted = append(converted, part)
				continue
			}
			partType := strings.TrimSpace(getString(partMap["type"]))
			switch partType {
			case "text":
				text := getString(partMap["text"])
				if text != "" {
					converted = append(converted, map[string]any{
						"type": "input_text",
						"text": text,
					})
					continue
				}
			case "image_url":
				// Accept both {"image_url":{"url":…}} and {"image_url":"…"}.
				imageURL := ""
				if imageObj, ok := partMap["image_url"].(map[string]any); ok {
					imageURL = getString(imageObj["url"])
				} else {
					imageURL = getString(partMap["image_url"])
				}
				if imageURL != "" {
					converted = append(converted, map[string]any{
						"type":      "input_image",
						"image_url": imageURL,
					})
					continue
				}
			case "input_text", "input_image":
				// Already in Responses form.
				converted = append(converted, partMap)
				continue
			}
			// Fallback: keep the part unchanged.
			converted = append(converted, partMap)
		}
		return converted
	default:
		return v
	}
}
// extractToolCallsFromMessage collects an assistant message's tool calls as
// Responses input items. Modern "tool_calls" entries become "tool_call"
// items; a legacy "function_call" becomes a "function_call" item whose call
// id falls back to the function name. Entries with neither name nor
// arguments are dropped.
func extractToolCallsFromMessage(msg map[string]any) []any {
	var out []any
	if toolCalls, ok := msg["tool_calls"].([]any); ok {
		for _, call := range toolCalls {
			callMap, ok := call.(map[string]any)
			if !ok {
				continue
			}
			callID := strings.TrimSpace(getString(callMap["id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(callMap["call_id"]))
			}
			name := ""
			args := ""
			if fn, ok := callMap["function"].(map[string]any); ok {
				name = strings.TrimSpace(getString(fn["name"]))
				args = getString(fn["arguments"])
			}
			if name == "" && args == "" {
				continue
			}
			item := map[string]any{
				"type": "tool_call",
			}
			if callID != "" {
				item["call_id"] = callID
			}
			if name != "" {
				item["name"] = name
			}
			if args != "" {
				item["arguments"] = args
			}
			out = append(out, item)
		}
	}
	// Legacy single function_call field (pre-tool_calls API).
	if fnCall, ok := msg["function_call"].(map[string]any); ok {
		name := strings.TrimSpace(getString(fnCall["name"]))
		args := getString(fnCall["arguments"])
		if name != "" || args != "" {
			callID := strings.TrimSpace(getString(msg["tool_call_id"]))
			if callID == "" {
				callID = name
			}
			item := map[string]any{
				"type": "function_call",
			}
			if callID != "" {
				item["call_id"] = callID
			}
			if name != "" {
				item["name"] = name
			}
			if args != "" {
				item["arguments"] = args
			}
			out = append(out, item)
		}
	}
	return out
}
// extractMessageContentText flattens message content to plain text: strings
// pass through, part arrays contribute their text-like parts ("", "text",
// "output_text", "input_text") concatenated, anything else yields "".
func extractMessageContentText(content any) string {
	if s, ok := content.(string); ok {
		return s
	}
	parts, ok := content.([]any)
	if !ok {
		return ""
	}
	var b strings.Builder
	for _, part := range parts {
		pm, ok := part.(map[string]any)
		if !ok {
			continue
		}
		switch strings.TrimSpace(getString(pm["type"])) {
		case "", "text", "output_text", "input_text":
			if text := getString(pm["text"]); text != "" {
				b.WriteString(text)
			}
		}
	}
	return b.String()
}
// isEmptyContent reports whether converted content carries nothing: nil,
// a blank string, or an empty part array. Any other type counts as content.
func isEmptyContent(content any) bool {
	if content == nil {
		return true
	}
	switch v := content.(type) {
	case string:
		return strings.TrimSpace(v) == ""
	case []any:
		return len(v) == 0
	}
	return false
}
// extractResponseTextAndToolCalls walks a Responses "output" array and
// returns the concatenated assistant text plus any tool calls converted to
// Chat Completions form. When "output" is absent it falls back to the flat
// "output_text" field.
func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) {
	output, ok := resp["output"].([]any)
	if !ok {
		if text, ok := resp["output_text"].(string); ok {
			return text, nil
		}
		return "", nil
	}
	textParts := make([]string, 0)
	toolCalls := make([]any, 0)
	for _, item := range output {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType := strings.TrimSpace(getString(itemMap["type"]))
		// Top-level tool/function call items.
		if itemType == "tool_call" || itemType == "function_call" {
			if tc := responseItemToChatToolCall(itemMap); tc != nil {
				toolCalls = append(toolCalls, tc)
			}
			continue
		}
		// Otherwise scan the item's content: plain string or part array.
		content := itemMap["content"]
		switch v := content.(type) {
		case string:
			if v != "" {
				textParts = append(textParts, v)
			}
		case []any:
			for _, part := range v {
				partMap, ok := part.(map[string]any)
				if !ok {
					continue
				}
				partType := strings.TrimSpace(getString(partMap["type"]))
				switch partType {
				case "output_text", "text", "input_text":
					text := getString(partMap["text"])
					if text != "" {
						textParts = append(textParts, text)
					}
				case "tool_call", "function_call":
					// Tool calls can also appear nested inside content.
					if tc := responseItemToChatToolCall(partMap); tc != nil {
						toolCalls = append(toolCalls, tc)
					}
				}
			}
		}
	}
	return strings.Join(textParts, ""), toolCalls
}
// responseItemToChatToolCall converts one Responses tool/function-call item
// into a Chat Completions tool_calls entry. It accepts flat call_id/name/
// arguments fields or a nested "function" object, generates a call id when
// missing, and returns nil for items with no usable fields.
func responseItemToChatToolCall(item map[string]any) map[string]any {
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	arguments := getString(item["arguments"])
	// The nested "function" object only fills gaps left by flat fields.
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if arguments == "" {
			arguments = getString(fn["arguments"])
		}
	}
	if name == "" && arguments == "" && callID == "" {
		return nil
	}
	if callID == "" {
		callID = "call_" + safeRandomHex(6)
	}
	return map[string]any{
		"id":   callID,
		"type": "function",
		"function": map[string]any{
			"name":      name,
			"arguments": arguments,
		},
	}
}
// extractResponseUsage converts Responses usage (input/output tokens) into
// Chat Completions usage; all-zero or absent usage yields nil.
func extractResponseUsage(resp map[string]any) map[string]any {
	raw, ok := resp["usage"].(map[string]any)
	if !ok {
		return nil
	}
	prompt := int(getNumber(raw["input_tokens"]))
	completion := int(getNumber(raw["output_tokens"]))
	if prompt == 0 && completion == 0 {
		return nil
	}
	return map[string]any{
		"prompt_tokens":     prompt,
		"completion_tokens": completion,
		"total_tokens":      prompt + completion,
	}
}
// getString coerces a decoded JSON value to a string; unsupported types
// yield "".
func getString(value any) string {
	switch v := value.(type) {
	case json.Number:
		return v.String()
	case []byte:
		return string(v)
	case string:
		return v
	}
	return ""
}
// getNumber coerces common numeric JSON representations to float64;
// unsupported types yield 0.
func getNumber(value any) float64 {
	if n, ok := value.(json.Number); ok {
		f, _ := n.Float64()
		return f
	}
	switch v := value.(type) {
	case float64:
		return v
	case float32:
		return float64(v)
	case int:
		return float64(v)
	case int64:
		return float64(v)
	}
	return 0
}
// getInt64 coerces common numeric JSON representations to int64, truncating
// floats; unsupported types yield 0.
func getInt64(value any) int64 {
	if n, ok := value.(json.Number); ok {
		i, _ := n.Int64()
		return i
	}
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case float64:
		return int64(v)
	}
	return 0
}
// safeRandomHex returns a random hex string of byteLength bytes, falling
// back to the constant "000000" when the underlying generator fails or
// returns an empty value. It never fails and is intended for best-effort
// ID generation only, not for secrets.
// NOTE(review): the fixed 6-char fallback may be narrower than callers
// expect for byteLength > 3 — confirm against randomHexString's contract.
func safeRandomHex(byteLength int) string {
	value, err := randomHexString(byteLength)
	if err != nil || value == "" {
		return "000000"
	}
	return value
}

View File

@@ -0,0 +1,486 @@
package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/gin-gonic/gin"
)
// chatStreamingResult aggregates what the SSE relay loop learned about one
// streamed chat-completions exchange.
type chatStreamingResult struct {
	usage        *OpenAIUsage // token usage parsed from stream events, if any
	firstTokenMs *int         // latency to the first delta; nil if none observed
}
// forwardChatCompletions proxies a Chat Completions request to the upstream
// endpoint for the given account. It applies the account's model mapping,
// injects stream_options.include_usage for streams when the client asked for
// usage, classifies upstream errors for failover, and returns a forward
// result with usage/latency accounting (or an error, possibly an
// *UpstreamFailoverError so the caller can retry another account).
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
	// Parse request body once (avoid multiple parse/serialize cycles)
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	reqModel, _ := reqBody["model"].(string)
	reqStream, _ := reqBody["stream"].(bool)
	originalModel := reqModel
	bodyModified := false
	// Rewrite the model name when the account defines a mapping; the original
	// name is kept for the result so accounting reports what the client sent.
	mappedModel := account.GetMappedModel(reqModel)
	if mappedModel != reqModel {
		log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
		reqBody["model"] = mappedModel
		bodyModified = true
	}
	// Ask upstream to report usage in the stream if the client wants it and
	// did not already request it explicitly.
	if reqStream && includeUsage {
		streamOptions, _ := reqBody["stream_options"].(map[string]any)
		if streamOptions == nil {
			streamOptions = map[string]any{}
		}
		if _, ok := streamOptions["include_usage"]; !ok {
			streamOptions["include_usage"] = true
			reqBody["stream_options"] = streamOptions
			bodyModified = true
		}
	}
	// Re-serialize only when something actually changed.
	if bodyModified {
		var err error
		body, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("serialize request body: %w", err)
		}
	}
	// Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}
	upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
	if err != nil {
		return nil, err
	}
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	// NOTE(review): c is nil-guarded here but used unconditionally below
	// (c.JSON on request failure) — confirm callers never pass a nil context.
	if c != nil {
		c.Set(OpsUpstreamRequestBodyKey, string(body))
	}
	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: record it for ops and answer 502 directly.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream request failed",
			},
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			// Read (bounded) and restore the body so later handlers can
			// still consume it after we inspect the error.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))
			upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
			upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
			upstreamDetail := ""
			// Optionally retain a truncated copy of the raw error body for ops.
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})
			s.handleFailoverSideEffects(ctx, resp, account)
			// Signal the caller to retry on another account.
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		// Non-failover errors are relayed to the client as-is.
		return s.handleErrorResponse(ctx, resp, c, account, body)
	}
	var usage *OpenAIUsage
	var firstTokenMs *int
	if reqStream {
		streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
	} else {
		usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
	}
	if usage == nil {
		usage = &OpenAIUsage{}
	}
	return &OpenAIForwardResult{
		RequestID:    resp.Header.Get("x-request-id"),
		Usage:        *usage,
		Model:        originalModel,
		Stream:       reqStream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}
// buildChatCompletionsRequest constructs the upstream POST, targeting the
// account's validated base URL (or the default endpoint), carrying the
// bearer token, the allow-listed client headers, and the account's custom
// user agent when configured.
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
	targetURL := openaiChatAPIURL
	if baseURL := account.GetOpenAIBaseURL(); baseURL != "" {
		validated, err := s.validateUpstreamBaseURL(baseURL)
		if err != nil {
			return nil, err
		}
		targetURL = validated + "/chat/completions"
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("authorization", "Bearer "+token)
	// Forward only allow-listed client headers (case-insensitive match).
	for key, values := range c.Request.Header {
		if !openaiChatAllowedHeaders[strings.ToLower(key)] {
			continue
		}
		for _, v := range values {
			req.Header.Add(key, v)
		}
	}
	if ua := account.GetOpenAIUserAgent(); ua != "" {
		req.Header.Set("user-agent", ua)
	}
	if req.Header.Get("content-type") == "" {
		req.Header.Set("content-type", "application/json")
	}
	return req, nil
}
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
keepaliveInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
}
var keepaliveTicker *time.Ticker
if keepaliveInterval > 0 {
keepaliveTicker = time.NewTicker(keepaliveInterval)
defer keepaliveTicker.Stop()
}
var keepaliveCh <-chan time.Time
if keepaliveTicker != nil {
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
for {
select {
case ev, ok := <-events:
if !ok {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
lastDataAt = time.Now()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
line = "data: " + correctedData
}
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
if firstTokenMs == nil {
if event := parseChatStreamEvent(data); event != nil {
if chatChunkHasDelta(event) {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
applyChatUsageFromEvent(event, usage)
}
} else {
if event := parseChatStreamEvent(data); event != nil {
applyChatUsageFromEvent(event, usage)
}
}
} else {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
}
// handleChatCompletionsNonStreamingResponse relays a non-streaming
// chat-completions upstream response to the client.
//
// It reads the full upstream body, best-effort extracts token usage from the
// top-level "usage" object (a body that fails to decode is still forwarded
// with zeroed usage), rewrites the model name back to the one the caller
// requested when a mapping was applied, and runs tool-call correction before
// writing the response.
//
// Returns the parsed usage, or a nil usage with an error when the upstream
// body cannot be read.
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	usage := &OpenAIUsage{}
	var parsed map[string]any
	if json.Unmarshal(body, &parsed) == nil {
		if usageMap, ok := parsed["usage"].(map[string]any); ok {
			applyChatUsageFromMap(usageMap, usage)
		}
	}
	if originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}
	body = s.correctToolCallsInResponseBody(body)
	// Guard the config dereference: the rest of this service (and the
	// Content-Type branch below) treats s.cfg as possibly nil, so the header
	// filtering must not be the one unguarded access.
	if s.cfg != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
	}
	contentType := "application/json"
	if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
		// Header filtering disabled: pass the upstream Content-Type through.
		if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
			contentType = upstreamType
		}
	}
	c.Data(resp.StatusCode, contentType, body)
	return usage, nil
}
// parseChatStreamEvent decodes one SSE data payload into a generic JSON
// object. It returns nil for empty payloads, the "[DONE]" stream terminator,
// and payloads that are not valid JSON objects.
func parseChatStreamEvent(data string) map[string]any {
	switch data {
	case "", "[DONE]":
		return nil
	}
	var event map[string]any
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return nil
	}
	return event
}
// chatChunkHasDelta reports whether a chat-completion stream chunk carries
// actual generated output: non-whitespace text content, at least one tool
// call, or a non-empty legacy function_call in any choice's delta.
func chatChunkHasDelta(event map[string]any) bool {
	choices, isList := event["choices"].([]any)
	if !isList {
		return false
	}
	for _, raw := range choices {
		choice, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		delta, hasDelta := choice["delta"].(map[string]any)
		if !hasDelta {
			continue
		}
		// Whitespace-only text does not count as a first token.
		if text, isStr := delta["content"].(string); isStr && strings.TrimSpace(text) != "" {
			return true
		}
		if calls, isSlice := delta["tool_calls"].([]any); isSlice && len(calls) > 0 {
			return true
		}
		if legacy, isFn := delta["function_call"].(map[string]any); isFn && len(legacy) > 0 {
			return true
		}
	}
	return false
}
// applyChatUsageFromEvent copies token counts from a parsed stream event
// into usage when the event carries a top-level "usage" object; events
// without one (and nil arguments) are ignored.
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
	if event == nil || usage == nil {
		return
	}
	if usageMap, ok := event["usage"].(map[string]any); ok {
		applyChatUsageFromMap(usageMap, usage)
	}
}
// applyChatUsageFromMap merges prompt/completion token counts from a raw
// usage object into usage. Only positive counts overwrite the stored values,
// so a later chunk with zero or missing counts cannot erase earlier totals.
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
	if usageMap == nil || usage == nil {
		return
	}
	if n := int(getNumber(usageMap["prompt_tokens"])); n > 0 {
		usage.InputTokens = n
	}
	if n := int(getNumber(usageMap["completion_tokens"])); n > 0 {
		usage.OutputTokens = n
	}
}

View File

@@ -0,0 +1,132 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
// TestConvertChatCompletionsToResponses verifies that a chat-completions
// request using the legacy functions/function_call fields and a tool
// round-trip is translated into the Responses API input/tools shape.
func TestConvertChatCompletionsToResponses(t *testing.T) {
	userMsg := map[string]any{"role": "user", "content": "hello"}
	assistantMsg := map[string]any{
		"role": "assistant",
		"tool_calls": []any{
			map[string]any{
				"id":   "call_1",
				"type": "function",
				"function": map[string]any{
					"name":      "ping",
					"arguments": "{}",
				},
			},
		},
	}
	toolMsg := map[string]any{
		"role":          "tool",
		"tool_call_id":  "call_1",
		"content":       "ok",
		"response":      "ignored",
		"response_time": 1,
	}
	req := map[string]any{
		"model":    "gpt-4o",
		"messages": []any{userMsg, assistantMsg, toolMsg},
		"functions": []any{
			map[string]any{
				"name":        "ping",
				"description": "ping tool",
				"parameters":  map[string]any{"type": "object"},
			},
		},
		"function_call": map[string]any{"name": "ping"},
	}

	converted, err := ConvertChatCompletionsToResponses(req)
	require.NoError(t, err)
	require.Equal(t, "gpt-4o", converted["model"])

	input, ok := converted["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 3)

	toolCall := findInputItemByType(input, "tool_call")
	require.NotNil(t, toolCall)
	require.Equal(t, "call_1", toolCall["call_id"])

	toolOutput := findInputItemByType(input, "function_call_output")
	require.NotNil(t, toolOutput)
	require.Equal(t, "call_1", toolOutput["call_id"])

	tools, ok := converted["tools"].([]any)
	require.True(t, ok)
	require.Len(t, tools, 1)
	require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"])
}
// TestConvertResponsesToChatCompletion verifies that a Responses API payload
// is rewritten into the classic chat-completion shape, carrying over the
// assistant message text and aggregating token usage into a total.
func TestConvertResponsesToChatCompletion(t *testing.T) {
	outputText := map[string]any{"type": "output_text", "text": "hi"}
	assistantMessage := map[string]any{
		"type":    "message",
		"role":    "assistant",
		"content": []any{outputText},
	}
	resp := map[string]any{
		"id":         "resp_123",
		"model":      "gpt-4o",
		"created_at": 1700000000,
		"output":     []any{assistantMessage},
		"usage": map[string]any{
			"input_tokens":  2,
			"output_tokens": 3,
		},
	}

	body, err := json.Marshal(resp)
	require.NoError(t, err)

	converted, err := ConvertResponsesToChatCompletion(body)
	require.NoError(t, err)

	var chat map[string]any
	require.NoError(t, json.Unmarshal(converted, &chat))
	require.Equal(t, "chat.completion", chat["object"])

	choices, ok := chat["choices"].([]any)
	require.True(t, ok)
	require.Len(t, choices, 1)
	firstChoice, ok := choices[0].(map[string]any)
	require.True(t, ok)
	message, ok := firstChoice["message"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "hi", message["content"])

	usage, ok := chat["usage"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, float64(2), usage["prompt_tokens"])
	require.Equal(t, float64(3), usage["completion_tokens"])
	require.Equal(t, float64(5), usage["total_tokens"])
}
// findInputItemByType returns the first input item whose "type" field equals
// itemType, or nil when no map-shaped item matches.
func findInputItemByType(items []any, itemType string) map[string]any {
	for _, candidate := range items {
		entry, isMap := candidate.(map[string]any)
		if !isMap {
			continue
		}
		if entry["type"] == itemType {
			return entry
		}
	}
	return nil
}

View File

@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
@@ -33,6 +34,7 @@ const (
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiChatAPIURL = "https://api.openai.com/v1/chat/completions"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
codexCLIUserAgent = "codex_cli_rs/0.98.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
@@ -42,6 +44,16 @@ const (
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
)
// OpenAIChatCompletionsBodyKey is the gin.Context key under which the handler
// stores the original chat-completions payload for upstream passthrough.
const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body"
// OpenAIChatCompletionsIncludeUsageKey is the gin.Context key under which the
// handler stores the request's stream_options.include_usage flag.
const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage"
// openaiSSEDataRe matches the "data:" prefix of an SSE data line, tolerating
// any amount of whitespace after the colon — some upstreams emit the
// non-standard "data:" with no space (the spec form is "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
@@ -81,6 +93,19 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
"X-Real-IP",
}
// openaiChatAllowedHeaders is the lowercased header whitelist for the
// chat-completions path: the responses whitelist extended with the
// openai-organization/project/beta headers.
// NOTE(review): presumably consulted when filtering client headers before
// forwarding upstream, mirroring openaiAllowedHeaders — confirm at call site.
var openaiChatAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"conversation_id": true,
"user-agent": true,
"originator": true,
"session_id": true,
"openai-organization": true,
"openai-project": true,
"openai-beta": true,
}
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
type OpenAICodexUsageSnapshot struct {
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
@@ -1005,6 +1030,23 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
}
if c != nil && account != nil && account.Type == AccountTypeAPIKey {
if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok {
if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 {
includeUsage := false
if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok {
if flag, ok := v.(bool); ok {
includeUsage = flag
}
}
if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok {
passthroughWriter.SetPassthrough()
}
return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime)
}
}
}
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel