mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 03:26:10 +00:00
feat: add OpenAI chat completions compatibility
This commit is contained in:
530
backend/internal/handler/openai_chat_completions.go
Normal file
530
backend/internal/handler/openai_chat_completions.go
Normal file
@@ -0,0 +1,530 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API compatibility.
// POST /v1/chat/completions
//
// It translates the incoming Chat Completions request into a Responses-API
// request, delegates to the existing Responses handler, and rewrites the
// upstream reply back into chat-completions form via a wrapping writer.
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		// Distinguish an oversized body (MaxBytesReader-style limit) from other read failures.
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}
	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}

	// Preserve original chat-completions request for upstream passthrough when needed.
	c.Set(service.OpenAIChatCompletionsBodyKey, body)

	var chatReq map[string]any
	if err := json.Unmarshal(body, &chatReq); err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// stream_options.include_usage decides whether the final streamed chunk carries usage.
	includeUsage := false
	if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok {
		if v, ok := streamOptions["include_usage"].(bool); ok {
			includeUsage = v
		}
	}
	c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage)

	// Convert the request payload into Responses-API shape.
	converted, err := service.ConvertChatCompletionsToResponses(chatReq)
	if err != nil {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
		return
	}

	convertedBody, err := json.Marshal(converted)
	if err != nil {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
		return
	}

	stream, _ := converted["stream"].(bool)
	model, _ := converted["model"].(string)
	// Wrap the writer so Responses output is rewritten into chat-completions format.
	writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
	c.Writer = writer
	// Swap in the converted payload before delegating to the Responses handler.
	c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
	c.Request.ContentLength = int64(len(convertedBody))

	h.Responses(c)
	// Flush any buffered non-streaming response once the delegate returns.
	writer.Finalize()
}
|
||||
|
||||
// chatCompletionsResponseWriter wraps the gin writer handed to the Responses
// handler and rewrites its output into Chat Completions format.
type chatCompletionsResponseWriter struct {
	gin.ResponseWriter
	stream       bool                        // request asked for SSE streaming
	includeUsage bool                        // attach usage to the final streamed chunk
	buffer       bytes.Buffer                // accumulates the whole body in non-streaming mode
	streamBuf    bytes.Buffer                // holds partial SSE lines until a newline arrives
	state        *chatCompletionStreamState  // metadata carried across streamed chunks
	corrector    *service.CodexToolCorrector // rewrites tool-call payloads before emission
	finalized    bool                        // Finalize has already run; later writes are dropped
	passthrough  bool                        // bypass conversion, forward bytes verbatim
}
|
||||
|
||||
// chatCompletionStreamState tracks per-response metadata while converting a
// Responses SSE stream into chat.completion.chunk events.
type chatCompletionStreamState struct {
	id            string         // completion id (from upstream, or synthesized)
	model         string         // model name echoed into every chunk
	created       int64          // unix-seconds creation timestamp
	sentRole      bool           // the first delta must carry role:"assistant" exactly once
	sawToolCall   bool           // drives finish_reason "tool_calls"
	sawText       bool           // suppresses duplicate text from output_text.done
	toolCallIndex map[string]int // stable per-call index keyed by call id
	usage         map[string]any // usage captured from response.completed/done
}
|
||||
|
||||
func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter {
|
||||
return &chatCompletionsResponseWriter{
|
||||
ResponseWriter: w,
|
||||
stream: stream,
|
||||
includeUsage: includeUsage,
|
||||
state: &chatCompletionStreamState{
|
||||
model: strings.TrimSpace(model),
|
||||
toolCallIndex: make(map[string]int),
|
||||
},
|
||||
corrector: service.NewCodexToolCorrector(),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) {
|
||||
if w.passthrough {
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
if w.stream {
|
||||
n, err := w.streamBuf.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
w.flushStreamBuffer()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if w.finalized {
|
||||
return len(data), nil
|
||||
}
|
||||
return w.buffer.Write(data)
|
||||
}
|
||||
|
||||
// WriteString implements io.StringWriter by delegating to Write.
func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) {
	return w.Write([]byte(s))
}
|
||||
|
||||
// Finalize flushes a buffered non-streaming response, converting it from
// Responses format into a chat.completion object. It is a no-op for streaming
// or passthrough responses (those are written incrementally) and is idempotent:
// only the first call does any work.
func (w *chatCompletionsResponseWriter) Finalize() {
	if w.finalized {
		return
	}
	w.finalized = true
	if w.passthrough {
		return
	}
	if w.stream {
		return
	}

	body := w.buffer.Bytes()
	if len(body) == 0 {
		return
	}

	// The rewritten body has a different length than the upstream response.
	w.ResponseWriter.Header().Del("Content-Length")

	converted, err := service.ConvertResponsesToChatCompletion(body)
	if err != nil {
		// Conversion failed (e.g. upstream error JSON): forward raw body.
		_, _ = w.ResponseWriter.Write(body)
		return
	}

	// Run tool-call correction over the converted payload before emitting.
	corrected := converted
	if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok {
		corrected = []byte(correctedStr)
	}

	_, _ = w.ResponseWriter.Write(corrected)
}
|
||||
|
||||
// SetPassthrough disables all conversion; subsequent writes go to the
// underlying writer verbatim and Finalize becomes a no-op.
func (w *chatCompletionsResponseWriter) SetPassthrough() {
	w.passthrough = true
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
|
||||
for {
|
||||
buf := w.streamBuf.Bytes()
|
||||
idx := bytes.IndexByte(buf, '\n')
|
||||
if idx == -1 {
|
||||
return
|
||||
}
|
||||
lineBytes := w.streamBuf.Next(idx + 1)
|
||||
line := strings.TrimRight(string(lineBytes), "\r\n")
|
||||
w.handleStreamLine(line)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamLine processes one line of upstream SSE output. Comment lines
// (":" prefix, keep-alives) are forwarded verbatim, "data:" payloads are
// converted into chat-completion chunks, and everything else (blank lines,
// "event:" lines) is dropped.
func (w *chatCompletionsResponseWriter) handleStreamLine(line string) {
	if line == "" {
		return
	}
	if strings.HasPrefix(line, ":") {
		// SSE comment: pass through unchanged.
		_, _ = w.ResponseWriter.Write([]byte(line + "\n\n"))
		return
	}
	if !strings.HasPrefix(line, "data:") {
		return
	}

	data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	for _, chunk := range w.convertResponseDataToChatChunks(data) {
		if chunk == "" {
			continue
		}
		if chunk == "[DONE]" {
			_, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
			continue
		}
		_, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n"))
	}
}
|
||||
|
||||
// convertResponseDataToChatChunks translates one Responses-API SSE data
// payload into zero or more chat.completion.chunk payloads (JSON strings).
// The "[DONE]" sentinel is returned unchanged, and payloads that cannot be
// interpreted (non-JSON, error objects, missing type) are passed through
// verbatim so the client still sees upstream diagnostics.
func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string {
	if data == "" {
		return nil
	}
	if data == "[DONE]" {
		return []string{"[DONE]"}
	}

	var payload map[string]any
	if err := json.Unmarshal([]byte(data), &payload); err != nil {
		// Not JSON: forward as-is.
		return []string{data}
	}

	if _, ok := payload["error"]; ok {
		// Upstream error objects are forwarded untouched.
		return []string{data}
	}

	eventType := strings.TrimSpace(getString(payload["type"]))
	if eventType == "" {
		return []string{data}
	}

	// Capture id/model/created from whichever event carries them first.
	w.state.applyMetadata(payload)

	switch eventType {
	case "response.created":
		return nil
	case "response.output_text.delta":
		delta := getString(payload["delta"])
		if delta == "" {
			return nil
		}
		w.state.sawText = true
		return []string{w.buildTextDeltaChunk(delta)}
	case "response.output_text.done":
		// Emit the accumulated text only when no incremental deltas were
		// seen; otherwise the content would be duplicated.
		if w.state.sawText {
			return nil
		}
		text := getString(payload["text"])
		if text == "" {
			return nil
		}
		w.state.sawText = true
		return []string{w.buildTextDeltaChunk(text)}
	case "response.output_item.added", "response.output_item.delta":
		if item, ok := payload["item"].(map[string]any); ok {
			if callID, name, args, ok := extractToolCallFromItem(item); ok {
				w.state.sawToolCall = true
				return []string{w.buildToolCallChunk(callID, name, args)}
			}
		}
	case "response.completed", "response.done":
		if responseObj, ok := payload["response"].(map[string]any); ok {
			w.state.applyResponseUsage(responseObj)
		}
		return []string{w.buildFinalChunk()}
	}

	// Generic fallback for tool/function-call event types not matched above:
	// locate a call id under any of the known keys and forward the delta.
	if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") {
		callID := strings.TrimSpace(getString(payload["call_id"]))
		if callID == "" {
			callID = strings.TrimSpace(getString(payload["tool_call_id"]))
		}
		if callID == "" {
			callID = strings.TrimSpace(getString(payload["id"]))
		}
		args := getString(payload["delta"])
		name := strings.TrimSpace(getString(payload["name"]))
		if callID != "" && (args != "" || name != "") {
			w.state.sawToolCall = true
			return []string{w.buildToolCallChunk(callID, name, args)}
		}
	}

	return nil
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string {
|
||||
w.state.ensureDefaults()
|
||||
payload := map[string]any{
|
||||
"content": delta,
|
||||
}
|
||||
if !w.state.sentRole {
|
||||
payload["role"] = "assistant"
|
||||
w.state.sentRole = true
|
||||
}
|
||||
return w.buildChunk(payload, nil, nil)
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string {
|
||||
w.state.ensureDefaults()
|
||||
index := w.state.toolCallIndexFor(callID)
|
||||
function := map[string]any{}
|
||||
if name != "" {
|
||||
function["name"] = name
|
||||
}
|
||||
if args != "" {
|
||||
function["arguments"] = args
|
||||
}
|
||||
toolCall := map[string]any{
|
||||
"index": index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
"function": function,
|
||||
}
|
||||
|
||||
delta := map[string]any{
|
||||
"tool_calls": []any{toolCall},
|
||||
}
|
||||
if !w.state.sentRole {
|
||||
delta["role"] = "assistant"
|
||||
w.state.sentRole = true
|
||||
}
|
||||
|
||||
return w.buildChunk(delta, nil, nil)
|
||||
}
|
||||
|
||||
func (w *chatCompletionsResponseWriter) buildFinalChunk() string {
|
||||
w.state.ensureDefaults()
|
||||
finishReason := "stop"
|
||||
if w.state.sawToolCall {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
usage := map[string]any(nil)
|
||||
if w.includeUsage && w.state.usage != nil {
|
||||
usage = w.state.usage
|
||||
}
|
||||
return w.buildChunk(map[string]any{}, finishReason, usage)
|
||||
}
|
||||
|
||||
// buildChunk assembles a single chat.completion.chunk JSON payload with one
// choice at index 0. finishReason may be nil for intermediate chunks; usage is
// attached only when non-nil. Tool-call correction is applied to the
// serialized payload before returning.
func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string {
	w.state.ensureDefaults()
	chunk := map[string]any{
		"id":      w.state.id,
		"object":  "chat.completion.chunk",
		"created": w.state.created,
		"model":   w.state.model,
		"choices": []any{
			map[string]any{
				"index":         0,
				"delta":         delta,
				"finish_reason": finishReason,
			},
		},
	}
	if usage != nil {
		chunk["usage"] = usage
	}

	// Marshal of map[string]any with JSON-safe values cannot fail here.
	data, _ := json.Marshal(chunk)
	if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok {
		return corrected
	}
	return string(data)
}
|
||||
|
||||
// ensureDefaults backfills id/model/created with synthetic values so that
// emitted chunks are well-formed even if upstream never supplied metadata.
func (s *chatCompletionStreamState) ensureDefaults() {
	if s.id == "" {
		s.id = "chatcmpl-" + randomHexUnsafe(12)
	}
	if s.model == "" {
		s.model = "unknown"
	}
	if s.created == 0 {
		s.created = time.Now().Unix()
	}
}
|
||||
|
||||
func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int {
|
||||
if idx, ok := s.toolCallIndex[callID]; ok {
|
||||
return idx
|
||||
}
|
||||
idx := len(s.toolCallIndex)
|
||||
s.toolCallIndex[callID] = idx
|
||||
return idx
|
||||
}
|
||||
|
||||
// applyMetadata fills in id/model/created from a streamed event payload,
// consulting the nested "response" object first and never overwriting values
// already captured from an earlier event.
func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) {
	if responseObj, ok := payload["response"].(map[string]any); ok {
		s.applyResponseMetadata(responseObj)
	}

	if s.id == "" {
		// Prefer the explicit response_id over a generic event id.
		if id := strings.TrimSpace(getString(payload["response_id"])); id != "" {
			s.id = id
		} else if id := strings.TrimSpace(getString(payload["id"])); id != "" {
			s.id = id
		}
	}
	if s.model == "" {
		if model := strings.TrimSpace(getString(payload["model"])); model != "" {
			s.model = model
		}
	}
	if s.created == 0 {
		// created_at (Responses naming) wins over created (chat naming).
		if created := getInt64(payload["created_at"]); created != 0 {
			s.created = created
		} else if created := getInt64(payload["created"]); created != 0 {
			s.created = created
		}
	}
}
|
||||
|
||||
func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) {
|
||||
if s.id == "" {
|
||||
if id := strings.TrimSpace(getString(responseObj["id"])); id != "" {
|
||||
s.id = id
|
||||
}
|
||||
}
|
||||
if s.model == "" {
|
||||
if model := strings.TrimSpace(getString(responseObj["model"])); model != "" {
|
||||
s.model = model
|
||||
}
|
||||
}
|
||||
if s.created == 0 {
|
||||
if created := getInt64(responseObj["created_at"]); created != 0 {
|
||||
s.created = created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) {
|
||||
usage, ok := responseObj["usage"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
promptTokens := int(getNumber(usage["input_tokens"]))
|
||||
completionTokens := int(getNumber(usage["output_tokens"]))
|
||||
if promptTokens == 0 && completionTokens == 0 {
|
||||
return
|
||||
}
|
||||
s.usage = map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": completionTokens,
|
||||
"total_tokens": promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// extractToolCallFromItem pulls (callID, name, arguments) out of a Responses
// output item of type "tool_call"/"function_call". Fields may live either at
// the item's top level or nested under "function". Returns ok=false when the
// item is not a tool call or carries no usable fields; a missing call id is
// replaced with a random one so downstream indexing always has a key.
func extractToolCallFromItem(item map[string]any) (string, string, string, bool) {
	itemType := strings.TrimSpace(getString(item["type"]))
	if itemType != "tool_call" && itemType != "function_call" {
		return "", "", "", false
	}
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	args := getString(item["arguments"])
	// Nested function object acts as a fallback for both fields.
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if args == "" {
			args = getString(fn["arguments"])
		}
	}
	if callID == "" && name == "" && args == "" {
		return "", "", "", false
	}
	if callID == "" {
		callID = "call_" + randomHexUnsafe(6)
	}
	return callID, name, args, true
}
|
||||
|
||||
// getString coerces a decoded JSON value to a string, returning "" for any
// type other than string, []byte, or json.Number.
func getString(value any) string {
	if value == nil {
		return ""
	}
	switch v := value.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	case json.Number:
		return v.String()
	}
	return ""
}
|
||||
|
||||
// getNumber coerces a decoded JSON value to float64, returning 0 for any
// unsupported type (parse errors in json.Number are also treated as 0).
func getNumber(value any) float64 {
	switch v := value.(type) {
	case json.Number:
		f, _ := v.Float64()
		return f
	case float64:
		return v
	case float32:
		return float64(v)
	case int64:
		return float64(v)
	case int:
		return float64(v)
	}
	return 0
}
|
||||
|
||||
// getInt64 coerces a decoded JSON value to int64 (floats truncate toward
// zero), returning 0 for any unsupported type.
func getInt64(value any) int64 {
	switch v := value.(type) {
	case json.Number:
		i, _ := v.Int64()
		return i
	case int64:
		return v
	case int:
		return int64(v)
	case float64:
		return int64(v)
	}
	return 0
}
|
||||
|
||||
// randomHexUnsafe returns a hex string of 2*byteLength characters drawn from
// crypto/rand; non-positive lengths fall back to 8 random bytes. "Unsafe"
// refers to the error handling: if the system RNG fails, a deterministic
// all-zero string is returned instead of an error.
//
// Fix: the previous fallback returned a fixed "000000" (6 chars) regardless
// of the requested length, so callers building fixed-width IDs (e.g.
// "chatcmpl-" + randomHexUnsafe(12)) would get an inconsistently short id on
// RNG failure. The fallback now matches the successful-read width.
func randomHexUnsafe(byteLength int) string {
	if byteLength <= 0 {
		byteLength = 8
	}
	buf := make([]byte, byteLength)
	if _, err := rand.Read(buf); err != nil {
		// Same width as a successful read: byteLength bytes -> 2x hex chars.
		return strings.Repeat("0", byteLength*2)
	}
	return hex.EncodeToString(buf)
}
|
||||
@@ -1,8 +1,6 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -43,15 +41,8 @@ func RegisterGatewayRoutes(
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
||||
},
|
||||
})
|
||||
})
|
||||
// OpenAI Chat Completions API
|
||||
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
@@ -69,6 +60,8 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ChatCompletions)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
||||
|
||||
513
backend/internal/service/openai_chat_completions.go
Normal file
513
backend/internal/service/openai_chat_completions.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request.
// Unrecognized fields are copied through untouched; chat-only fields
// (messages, max_tokens/max_completion_tokens, legacy functions/function_call,
// stream_options) are translated or dropped. The input map is not modified.
func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) {
	if req == nil {
		return nil, errors.New("request is nil")
	}

	model := strings.TrimSpace(getString(req["model"]))
	if model == "" {
		return nil, errors.New("model is required")
	}

	messagesRaw, ok := req["messages"]
	if !ok {
		return nil, errors.New("messages is required")
	}
	messages, ok := messagesRaw.([]any)
	if !ok {
		return nil, errors.New("messages must be an array")
	}

	input, err := convertChatMessagesToResponsesInput(messages)
	if err != nil {
		return nil, err
	}

	// Copy everything except the chat-specific keys handled explicitly below.
	out := make(map[string]any, len(req)+1)
	for key, value := range req {
		switch key {
		case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call":
			continue
		default:
			out[key] = value
		}
	}

	out["model"] = model
	out["input"] = input

	// max_tokens / max_completion_tokens -> max_output_tokens; an explicit
	// max_output_tokens in the request is never overwritten.
	if _, ok := out["max_output_tokens"]; !ok {
		if v, ok := req["max_tokens"]; ok {
			out["max_output_tokens"] = v
		} else if v, ok := req["max_completion_tokens"]; ok {
			out["max_output_tokens"] = v
		}
	}

	// Legacy "functions" are wrapped as function tools when no "tools" given.
	if _, ok := out["tools"]; !ok {
		if functions, ok := req["functions"].([]any); ok && len(functions) > 0 {
			tools := make([]any, 0, len(functions))
			for _, fn := range functions {
				if fnMap, ok := fn.(map[string]any); ok {
					tools = append(tools, map[string]any{
						"type":     "function",
						"function": fnMap,
					})
				}
			}
			if len(tools) > 0 {
				out["tools"] = tools
			}
		}
	}

	// Legacy "function_call" maps onto "tool_choice" when absent.
	if _, ok := out["tool_choice"]; !ok {
		if functionCall, ok := req["function_call"]; ok {
			out["tool_choice"] = functionCall
		}
	}

	return out, nil
}
|
||||
|
||||
// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format.
// Missing id/created values are synthesized, and output text plus tool calls
// are folded into a single assistant message with a matching finish_reason.
func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) {
	var resp map[string]any
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}

	id := strings.TrimSpace(getString(resp["id"]))
	if id == "" {
		id = "chatcmpl-" + safeRandomHex(12)
	}
	model := strings.TrimSpace(getString(resp["model"]))

	// created_at (Responses naming) wins over created; fall back to now.
	created := getInt64(resp["created_at"])
	if created == 0 {
		created = getInt64(resp["created"])
	}
	if created == 0 {
		created = time.Now().Unix()
	}

	text, toolCalls := extractResponseTextAndToolCalls(resp)
	finishReason := "stop"
	if len(toolCalls) > 0 {
		finishReason = "tool_calls"
	}

	message := map[string]any{
		"role":    "assistant",
		"content": text,
	}
	if len(toolCalls) > 0 {
		message["tool_calls"] = toolCalls
	}

	chatResp := map[string]any{
		"id":      id,
		"object":  "chat.completion",
		"created": created,
		"model":   model,
		"choices": []any{
			map[string]any{
				"index":         0,
				"message":       message,
				"finish_reason": finishReason,
			},
		},
	}

	if usage := extractResponseUsage(resp); usage != nil {
		chatResp["usage"] = usage
	}
	if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" {
		chatResp["system_fingerprint"] = fingerprint
	}

	return json.Marshal(chatResp)
}
|
||||
|
||||
// convertChatMessagesToResponsesInput maps chat messages onto Responses input
// items. Tool/function result messages become "function_call_output" items;
// assistant tool calls are appended as separate items after the message, and
// a content-less assistant message that only carried tool calls is dropped.
func convertChatMessagesToResponsesInput(messages []any) ([]any, error) {
	input := make([]any, 0, len(messages))
	for _, msg := range messages {
		msgMap, ok := msg.(map[string]any)
		if !ok {
			return nil, errors.New("message must be an object")
		}
		role := strings.TrimSpace(getString(msgMap["role"]))
		if role == "" {
			return nil, errors.New("message role is required")
		}

		switch role {
		case "tool":
			callID := strings.TrimSpace(getString(msgMap["tool_call_id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(msgMap["id"]))
			}
			output := extractMessageContentText(msgMap["content"])
			input = append(input, map[string]any{
				"type":    "function_call_output",
				"call_id": callID,
				"output":  output,
			})
		case "function":
			// Legacy function role: the function name doubles as the call id.
			callID := strings.TrimSpace(getString(msgMap["name"]))
			output := extractMessageContentText(msgMap["content"])
			input = append(input, map[string]any{
				"type":    "function_call_output",
				"call_id": callID,
				"output":  output,
			})
		default:
			convertedContent := convertChatContent(msgMap["content"])
			toolCalls := []any(nil)
			if role == "assistant" {
				toolCalls = extractToolCallsFromMessage(msgMap)
			}
			// An assistant turn that only invoked tools produces no message item.
			skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent)
			if !skipAssistantMessage {
				msgItem := map[string]any{
					"role":    role,
					"content": convertedContent,
				}
				if name := strings.TrimSpace(getString(msgMap["name"])); name != "" {
					msgItem["name"] = name
				}
				input = append(input, msgItem)
			}
			if role == "assistant" && len(toolCalls) > 0 {
				input = append(input, toolCalls...)
			}
		}
	}
	return input, nil
}
|
||||
|
||||
// convertChatContent rewrites chat message content into Responses input parts:
// "text" parts become "input_text" and "image_url" parts become "input_image".
// Plain strings and already-converted parts pass through; unrecognized parts
// (and parts with empty text/url) are kept verbatim.
func convertChatContent(content any) any {
	switch v := content.(type) {
	case nil:
		return ""
	case string:
		return v
	case []any:
		converted := make([]any, 0, len(v))
		for _, part := range v {
			partMap, ok := part.(map[string]any)
			if !ok {
				converted = append(converted, part)
				continue
			}
			partType := strings.TrimSpace(getString(partMap["type"]))
			switch partType {
			case "text":
				text := getString(partMap["text"])
				if text != "" {
					converted = append(converted, map[string]any{
						"type": "input_text",
						"text": text,
					})
					continue
				}
			case "image_url":
				// image_url may be {"url": "..."} or a bare string.
				imageURL := ""
				if imageObj, ok := partMap["image_url"].(map[string]any); ok {
					imageURL = getString(imageObj["url"])
				} else {
					imageURL = getString(partMap["image_url"])
				}
				if imageURL != "" {
					converted = append(converted, map[string]any{
						"type":      "input_image",
						"image_url": imageURL,
					})
					continue
				}
			case "input_text", "input_image":
				converted = append(converted, partMap)
				continue
			}
			// Fallback: keep the part unchanged (also reached when text/url empty).
			converted = append(converted, partMap)
		}
		return converted
	default:
		return v
	}
}
|
||||
|
||||
// extractToolCallsFromMessage converts an assistant message's tool_calls (and
// legacy function_call) into Responses "tool_call"/"function_call" input
// items. Entries carrying neither a name nor arguments are skipped; empty
// fields are omitted from the produced items.
func extractToolCallsFromMessage(msg map[string]any) []any {
	var out []any
	if toolCalls, ok := msg["tool_calls"].([]any); ok {
		for _, call := range toolCalls {
			callMap, ok := call.(map[string]any)
			if !ok {
				continue
			}
			callID := strings.TrimSpace(getString(callMap["id"]))
			if callID == "" {
				callID = strings.TrimSpace(getString(callMap["call_id"]))
			}
			name := ""
			args := ""
			if fn, ok := callMap["function"].(map[string]any); ok {
				name = strings.TrimSpace(getString(fn["name"]))
				args = getString(fn["arguments"])
			}
			if name == "" && args == "" {
				continue
			}
			item := map[string]any{
				"type": "tool_call",
			}
			if callID != "" {
				item["call_id"] = callID
			}
			if name != "" {
				item["name"] = name
			}
			if args != "" {
				item["arguments"] = args
			}
			out = append(out, item)
		}
	}

	// Legacy single function_call: fall back to the function name as call id
	// when the message has no tool_call_id.
	if fnCall, ok := msg["function_call"].(map[string]any); ok {
		name := strings.TrimSpace(getString(fnCall["name"]))
		args := getString(fnCall["arguments"])
		if name != "" || args != "" {
			callID := strings.TrimSpace(getString(msg["tool_call_id"]))
			if callID == "" {
				callID = name
			}
			item := map[string]any{
				"type": "function_call",
			}
			if callID != "" {
				item["call_id"] = callID
			}
			if name != "" {
				item["name"] = name
			}
			if args != "" {
				item["arguments"] = args
			}
			out = append(out, item)
		}
	}

	return out
}
|
||||
|
||||
func extractMessageContentText(content any) string {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []any:
|
||||
parts := make([]string, 0, len(v))
|
||||
for _, part := range v {
|
||||
partMap, ok := part.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
partType := strings.TrimSpace(getString(partMap["type"]))
|
||||
if partType == "" || partType == "text" || partType == "output_text" || partType == "input_text" {
|
||||
text := getString(partMap["text"])
|
||||
if text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// isEmptyContent reports whether converted message content carries nothing:
// nil, a blank string, or an empty part slice. Any other type counts as
// non-empty.
func isEmptyContent(content any) bool {
	if content == nil {
		return true
	}
	switch v := content.(type) {
	case string:
		return strings.TrimSpace(v) == ""
	case []any:
		return len(v) == 0
	}
	return false
}
|
||||
|
||||
// extractResponseTextAndToolCalls walks a Responses "output" array and
// collects the concatenated assistant text plus any tool calls (already in
// chat tool_call shape). When "output" is absent, it falls back to the
// convenience "output_text" field.
func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) {
	output, ok := resp["output"].([]any)
	if !ok {
		if text, ok := resp["output_text"].(string); ok {
			return text, nil
		}
		return "", nil
	}

	textParts := make([]string, 0)
	toolCalls := make([]any, 0)

	for _, item := range output {
		itemMap, ok := item.(map[string]any)
		if !ok {
			continue
		}
		itemType := strings.TrimSpace(getString(itemMap["type"]))

		// Top-level tool/function call items.
		if itemType == "tool_call" || itemType == "function_call" {
			if tc := responseItemToChatToolCall(itemMap); tc != nil {
				toolCalls = append(toolCalls, tc)
			}
			continue
		}

		// Other items (e.g. messages) contribute their textual content;
		// tool calls nested inside content parts are also honored.
		content := itemMap["content"]
		switch v := content.(type) {
		case string:
			if v != "" {
				textParts = append(textParts, v)
			}
		case []any:
			for _, part := range v {
				partMap, ok := part.(map[string]any)
				if !ok {
					continue
				}
				partType := strings.TrimSpace(getString(partMap["type"]))
				switch partType {
				case "output_text", "text", "input_text":
					text := getString(partMap["text"])
					if text != "" {
						textParts = append(textParts, text)
					}
				case "tool_call", "function_call":
					if tc := responseItemToChatToolCall(partMap); tc != nil {
						toolCalls = append(toolCalls, tc)
					}
				}
			}
		}
	}

	return strings.Join(textParts, ""), toolCalls
}
|
||||
|
||||
// responseItemToChatToolCall maps a Responses tool/function-call item onto the
// chat-completions tool_call object shape. Fields may live at the top level or
// nested under "function". Returns nil when the item carries no usable fields;
// a missing call id is replaced with a random one.
func responseItemToChatToolCall(item map[string]any) map[string]any {
	callID := strings.TrimSpace(getString(item["call_id"]))
	if callID == "" {
		callID = strings.TrimSpace(getString(item["id"]))
	}
	name := strings.TrimSpace(getString(item["name"]))
	arguments := getString(item["arguments"])
	if fn, ok := item["function"].(map[string]any); ok {
		if name == "" {
			name = strings.TrimSpace(getString(fn["name"]))
		}
		if arguments == "" {
			arguments = getString(fn["arguments"])
		}
	}

	if name == "" && arguments == "" && callID == "" {
		return nil
	}

	if callID == "" {
		callID = "call_" + safeRandomHex(6)
	}

	return map[string]any{
		"id":   callID,
		"type": "function",
		"function": map[string]any{
			"name":      name,
			"arguments": arguments,
		},
	}
}
|
||||
|
||||
func extractResponseUsage(resp map[string]any) map[string]any {
|
||||
usage, ok := resp["usage"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
promptTokens := int(getNumber(usage["input_tokens"]))
|
||||
completionTokens := int(getNumber(usage["output_tokens"]))
|
||||
if promptTokens == 0 && completionTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"prompt_tokens": promptTokens,
|
||||
"completion_tokens": completionTokens,
|
||||
"total_tokens": promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// getString coerces a decoded JSON value to a string, returning "" for any
// type other than string, []byte, or json.Number.
func getString(value any) string {
	if value == nil {
		return ""
	}
	switch v := value.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	case json.Number:
		return v.String()
	}
	return ""
}
|
||||
|
||||
// getNumber coerces a decoded JSON value to float64, returning 0 for any
// unsupported type (parse errors in json.Number are also treated as 0).
func getNumber(value any) float64 {
	switch v := value.(type) {
	case json.Number:
		f, _ := v.Float64()
		return f
	case float64:
		return v
	case float32:
		return float64(v)
	case int64:
		return float64(v)
	case int:
		return float64(v)
	}
	return 0
}
|
||||
|
||||
// getInt64 coerces a decoded JSON value to int64 (floats truncate toward
// zero), returning 0 for any unsupported type.
func getInt64(value any) int64 {
	switch v := value.(type) {
	case json.Number:
		i, _ := v.Int64()
		return i
	case int64:
		return v
	case int:
		return int64(v)
	case float64:
		return int64(v)
	}
	return 0
}
|
||||
|
||||
func safeRandomHex(byteLength int) string {
|
||||
value, err := randomHexString(byteLength)
|
||||
if err != nil || value == "" {
|
||||
return "000000"
|
||||
}
|
||||
return value
|
||||
}
|
||||
486
backend/internal/service/openai_chat_completions_forward.go
Normal file
486
backend/internal/service/openai_chat_completions_forward.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// chatStreamingResult aggregates the outcome of relaying a streamed
// chat-completions response: the token usage reported by the upstream
// (zero-valued when none was seen) and the time-to-first-token in
// milliseconds (nil when no content delta ever arrived).
type chatStreamingResult struct {
	usage        *OpenAIUsage
	firstTokenMs *int
}
|
||||
|
||||
// forwardChatCompletions proxies an OpenAI chat-completions request upstream
// for an API-key account and relays the response (streaming or buffered) to
// the client.
//
// It applies the account's model mapping, injects stream_options.include_usage
// when the caller streamed with includeUsage, records ops telemetry, and
// translates upstream failures into either a failover error (so the caller can
// retry on another account) or a terminal error response. On success it
// returns an OpenAIForwardResult describing usage, model, and latency.
func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) {
	// Parse request body once (avoid multiple parse/serialize cycles)
	var reqBody map[string]any
	if err := json.Unmarshal(body, &reqBody); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}

	reqModel, _ := reqBody["model"].(string)
	reqStream, _ := reqBody["stream"].(bool)
	// Keep the caller's model name so responses can be rewritten back to it.
	originalModel := reqModel

	// Only re-serialize the body if we actually changed it.
	bodyModified := false
	mappedModel := account.GetMappedModel(reqModel)
	if mappedModel != reqModel {
		log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
		reqBody["model"] = mappedModel
		bodyModified = true
	}

	// Force include_usage upstream when the client asked for usage in the
	// stream, but never override an explicit client-provided value.
	if reqStream && includeUsage {
		streamOptions, _ := reqBody["stream_options"].(map[string]any)
		if streamOptions == nil {
			streamOptions = map[string]any{}
		}
		if _, ok := streamOptions["include_usage"]; !ok {
			streamOptions["include_usage"] = true
			reqBody["stream_options"] = streamOptions
			bodyModified = true
		}
	}

	if bodyModified {
		var err error
		body, err = json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("serialize request body: %w", err)
		}
	}

	// Get access token
	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}

	upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token)
	if err != nil {
		return nil, err
	}

	// Route through the account's proxy when one is configured.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	if c != nil {
		c.Set(OpsUpstreamRequestBodyKey, string(body))
	}

	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
	if err != nil {
		// Transport-level failure: record telemetry and answer 502.
		// NOTE(review): `c` is nil-checked above but dereferenced
		// unconditionally here (c.JSON) — confirm callers always pass a
		// non-nil context, or guard this path too.
		safeErr := sanitizeUpstreamErrorMessage(err.Error())
		setOpsUpstreamError(c, 0, safeErr, "")
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: 0,
			Kind:               "request_error",
			Message:            safeErr,
		})
		c.JSON(http.StatusBadGateway, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream request failed",
			},
		})
		return nil, fmt.Errorf("upstream request failed: %s", safeErr)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode >= 400 {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			// Buffer (up to 2 MiB of) the error body so it can be both
			// inspected here and re-read by failover side effects.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))

			upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
			upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
			upstreamDetail := ""
			// Optionally capture a truncated copy of the raw error body
			// for operator logs.
			if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
				maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
				if maxBytes <= 0 {
					maxBytes = 2048
				}
				upstreamDetail = truncateString(string(respBody), maxBytes)
			}
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "failover",
				Message:            upstreamMsg,
				Detail:             upstreamDetail,
			})

			s.handleFailoverSideEffects(ctx, resp, account)
			// Signal the caller to retry the request on another account.
			return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
		}
		// Non-failover errors are relayed to the client as-is.
		return s.handleErrorResponse(ctx, resp, c, account, body)
	}

	var usage *OpenAIUsage
	var firstTokenMs *int
	if reqStream {
		streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
	} else {
		usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
	}

	// Guarantee a non-nil usage for the result even when upstream sent none.
	if usage == nil {
		usage = &OpenAIUsage{}
	}

	return &OpenAIForwardResult{
		RequestID:    resp.Header.Get("x-request-id"),
		Usage:        *usage,
		Model:        originalModel,
		Stream:       reqStream,
		Duration:     time.Since(startTime),
		FirstTokenMs: firstTokenMs,
	}, nil
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) {
|
||||
var targetURL string
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
targetURL = openaiChatAPIURL
|
||||
} else {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/chat/completions"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiChatAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
err error
|
||||
}
|
||||
events := make(chan scanEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev scanEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
if !sendEvent(scanEvent{line: scanner.Text()}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
keepaliveInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
|
||||
keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
|
||||
}
|
||||
var keepaliveTicker *time.Ticker
|
||||
if keepaliveInterval > 0 {
|
||||
keepaliveTicker = time.NewTicker(keepaliveInterval)
|
||||
defer keepaliveTicker.Stop()
|
||||
}
|
||||
var keepaliveCh <-chan time.Time
|
||||
if keepaliveTicker != nil {
|
||||
keepaliveCh = keepaliveTicker.C
|
||||
}
|
||||
lastDataAt := time.Now()
|
||||
|
||||
errorEventSent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent {
|
||||
return
|
||||
}
|
||||
errorEventSent = true
|
||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
sendErrorEvent("response_too_large")
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
|
||||
}
|
||||
sendErrorEvent("stream_read_error")
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
line := ev.line
|
||||
lastDataAt = time.Now()
|
||||
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||
line = "data: " + correctedData
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
if firstTokenMs == nil {
|
||||
if event := parseChatStreamEvent(data); event != nil {
|
||||
if chatChunkHasDelta(event) {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
applyChatUsageFromEvent(event, usage)
|
||||
}
|
||||
} else {
|
||||
if event := parseChatStreamEvent(data); event != nil {
|
||||
applyChatUsageFromEvent(event, usage)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
sendErrorEvent("write_failed")
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
|
||||
}
|
||||
sendErrorEvent("stream_timeout")
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
|
||||
case <-keepaliveCh:
|
||||
if time.Since(lastDataAt) < keepaliveInterval {
|
||||
continue
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||
return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var parsed map[string]any
|
||||
if json.Unmarshal(body, &parsed) == nil {
|
||||
if usageMap, ok := parsed["usage"].(map[string]any); ok {
|
||||
applyChatUsageFromMap(usageMap, usage)
|
||||
}
|
||||
}
|
||||
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
body = s.correctToolCallsInResponseBody(body)
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
contentType := "application/json"
|
||||
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
|
||||
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
|
||||
contentType = upstreamType
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// parseChatStreamEvent decodes one SSE data payload into a generic map.
// Empty payloads, the "[DONE]" sentinel, and malformed JSON all yield nil.
func parseChatStreamEvent(data string) map[string]any {
	switch data {
	case "", "[DONE]":
		return nil
	}
	var event map[string]any
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return nil
	}
	return event
}
|
||||
|
||||
// chatChunkHasDelta reports whether any choice in a streaming chunk carries
// real delta content: non-whitespace text, at least one tool call, or a
// non-empty legacy function_call.
func chatChunkHasDelta(event map[string]any) bool {
	choices, ok := event["choices"].([]any)
	if !ok {
		return false
	}
	for _, raw := range choices {
		choice, isMap := raw.(map[string]any)
		if !isMap {
			continue
		}
		if delta, hasDelta := choice["delta"].(map[string]any); hasDelta && chatDeltaHasContent(delta) {
			return true
		}
	}
	return false
}

// chatDeltaHasContent reports whether a single delta object contains text,
// tool calls, or a legacy function_call payload.
func chatDeltaHasContent(delta map[string]any) bool {
	if text, ok := delta["content"].(string); ok && strings.TrimSpace(text) != "" {
		return true
	}
	if calls, ok := delta["tool_calls"].([]any); ok && len(calls) > 0 {
		return true
	}
	fc, ok := delta["function_call"].(map[string]any)
	return ok && len(fc) > 0
}
|
||||
|
||||
func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) {
|
||||
if event == nil || usage == nil {
|
||||
return
|
||||
}
|
||||
usageMap, ok := event["usage"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
applyChatUsageFromMap(usageMap, usage)
|
||||
}
|
||||
|
||||
func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) {
|
||||
if usageMap == nil || usage == nil {
|
||||
return
|
||||
}
|
||||
promptTokens := int(getNumber(usageMap["prompt_tokens"]))
|
||||
completionTokens := int(getNumber(usageMap["completion_tokens"]))
|
||||
if promptTokens > 0 {
|
||||
usage.InputTokens = promptTokens
|
||||
}
|
||||
if completionTokens > 0 {
|
||||
usage.OutputTokens = completionTokens
|
||||
}
|
||||
}
|
||||
132
backend/internal/service/openai_chat_completions_test.go
Normal file
132
backend/internal/service/openai_chat_completions_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConvertChatCompletionsToResponses verifies the chat-completions →
// Responses-API request conversion: messages (including an assistant
// tool_call and a tool result) become Responses "input" items, legacy
// "functions"/"function_call" become "tools"/"tool_choice", and unrelated
// tool-message fields ("response", "response_time") are ignored.
func TestConvertChatCompletionsToResponses(t *testing.T) {
	req := map[string]any{
		"model": "gpt-4o",
		"messages": []any{
			map[string]any{
				"role":    "user",
				"content": "hello",
			},
			map[string]any{
				"role": "assistant",
				"tool_calls": []any{
					map[string]any{
						"id":   "call_1",
						"type": "function",
						"function": map[string]any{
							"name":      "ping",
							"arguments": "{}",
						},
					},
				},
			},
			map[string]any{
				"role":          "tool",
				"tool_call_id":  "call_1",
				"content":       "ok",
				"response":      "ignored",
				"response_time": 1,
			},
		},
		"functions": []any{
			map[string]any{
				"name":        "ping",
				"description": "ping tool",
				"parameters":  map[string]any{"type": "object"},
			},
		},
		"function_call": map[string]any{"name": "ping"},
	}

	converted, err := ConvertChatCompletionsToResponses(req)
	require.NoError(t, err)
	require.Equal(t, "gpt-4o", converted["model"])

	// The three messages should map to exactly three input items.
	input, ok := converted["input"].([]any)
	require.True(t, ok)
	require.Len(t, input, 3)

	// The assistant tool call keeps its id as call_id.
	toolCall := findInputItemByType(input, "tool_call")
	require.NotNil(t, toolCall)
	require.Equal(t, "call_1", toolCall["call_id"])

	// The tool result becomes a function_call_output referencing the same call.
	toolOutput := findInputItemByType(input, "function_call_output")
	require.NotNil(t, toolOutput)
	require.Equal(t, "call_1", toolOutput["call_id"])

	// Legacy "functions" are lifted into Responses "tools".
	tools, ok := converted["tools"].([]any)
	require.True(t, ok)
	require.Len(t, tools, 1)

	require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"])
}
|
||||
|
||||
// TestConvertResponsesToChatCompletion verifies the Responses-API →
// chat-completions response conversion: an output_text message becomes the
// choice's message content, the object field is "chat.completion", and
// input/output token counts map to prompt/completion tokens with a computed
// total.
func TestConvertResponsesToChatCompletion(t *testing.T) {
	resp := map[string]any{
		"id":         "resp_123",
		"model":      "gpt-4o",
		"created_at": 1700000000,
		"output": []any{
			map[string]any{
				"type": "message",
				"role": "assistant",
				"content": []any{
					map[string]any{
						"type": "output_text",
						"text": "hi",
					},
				},
			},
		},
		"usage": map[string]any{
			"input_tokens":  2,
			"output_tokens": 3,
		},
	}
	body, err := json.Marshal(resp)
	require.NoError(t, err)

	converted, err := ConvertResponsesToChatCompletion(body)
	require.NoError(t, err)

	var chat map[string]any
	require.NoError(t, json.Unmarshal(converted, &chat))
	require.Equal(t, "chat.completion", chat["object"])

	// A single output message yields a single choice.
	choices, ok := chat["choices"].([]any)
	require.True(t, ok)
	require.Len(t, choices, 1)

	choice, ok := choices[0].(map[string]any)
	require.True(t, ok)
	message, ok := choice["message"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "hi", message["content"])

	// input/output tokens map to prompt/completion; total is their sum.
	usage, ok := chat["usage"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, float64(2), usage["prompt_tokens"])
	require.Equal(t, float64(3), usage["completion_tokens"])
	require.Equal(t, float64(5), usage["total_tokens"])
}
|
||||
|
||||
// findInputItemByType returns the first map-shaped item whose "type" field
// equals itemType, or nil when none matches.
func findInputItemByType(items []any, itemType string) map[string]any {
	for _, item := range items {
		if m, ok := item.(map[string]any); ok && m["type"] == itemType {
			return m
		}
	}
	return nil
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -33,6 +34,7 @@ const (
|
||||
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
// OpenAI Platform API for API Key accounts (fallback)
|
||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||
openaiChatAPIURL = "https://api.openai.com/v1/chat/completions"
|
||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||
codexCLIUserAgent = "codex_cli_rs/0.98.0"
|
||||
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
|
||||
@@ -42,6 +44,16 @@ const (
|
||||
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
|
||||
)
|
||||
|
||||
// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context.
|
||||
const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body"
|
||||
|
||||
// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context.
|
||||
const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage"
|
||||
|
||||
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-passthrough).
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
@@ -81,6 +93,19 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
|
||||
"X-Real-IP",
|
||||
}
|
||||
|
||||
// OpenAI chat-completions allowed headers (extend responses whitelist).
// Keys must be lowercase: buildChatCompletionsRequest compares them against
// strings.ToLower of each incoming header name before forwarding.
var openaiChatAllowedHeaders = map[string]bool{
	"accept-language":     true,
	"content-type":        true,
	"conversation_id":     true,
	"user-agent":          true,
	"originator":          true,
	"session_id":          true,
	"openai-organization": true,
	"openai-project":      true,
	"openai-beta":         true,
}
|
||||
|
||||
// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
|
||||
type OpenAICodexUsageSnapshot struct {
|
||||
PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
|
||||
@@ -1005,6 +1030,23 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
|
||||
}
|
||||
|
||||
if c != nil && account != nil && account.Type == AccountTypeAPIKey {
|
||||
if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok {
|
||||
if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 {
|
||||
includeUsage := false
|
||||
if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok {
|
||||
if flag, ok := v.(bool); ok {
|
||||
includeUsage = flag
|
||||
}
|
||||
}
|
||||
if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok {
|
||||
passthroughWriter.SetPassthrough()
|
||||
}
|
||||
return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
originalBody := body
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
|
||||
Reference in New Issue
Block a user