mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 04:05:29 +00:00
- apply default mapped model only when scheduling fallback is actually used - preserve reasoning in OpenAI-compatible output via reasoning_content and avoid invalid input function_call ids
285 lines
9.7 KiB
Go
285 lines
9.7 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"time"
|
|
|
|
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ChatCompletions handles OpenAI Chat Completions API requests.
|
|
// POST /v1/chat/completions
|
|
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|
streamStarted := false
|
|
defer h.recoverResponsesPanic(c, &streamStarted)
|
|
|
|
requestStart := time.Now()
|
|
|
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
|
if !ok {
|
|
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
|
return
|
|
}
|
|
|
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
|
if !ok {
|
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
|
return
|
|
}
|
|
reqLog := requestLogger(
|
|
c,
|
|
"handler.openai_gateway.chat_completions",
|
|
zap.Int64("user_id", subject.UserID),
|
|
zap.Int64("api_key_id", apiKey.ID),
|
|
zap.Any("group_id", apiKey.GroupID),
|
|
)
|
|
|
|
if !h.ensureResponsesDependencies(c, reqLog) {
|
|
return
|
|
}
|
|
|
|
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
|
if err != nil {
|
|
if maxErr, ok := extractMaxBytesError(err); ok {
|
|
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
|
return
|
|
}
|
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
|
return
|
|
}
|
|
if len(body) == 0 {
|
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
|
return
|
|
}
|
|
|
|
if !gjson.ValidBytes(body) {
|
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
return
|
|
}
|
|
|
|
modelResult := gjson.GetBytes(body, "model")
|
|
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
|
return
|
|
}
|
|
reqModel := modelResult.String()
|
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
|
|
|
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
|
|
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
|
|
|
if h.errorPassthroughService != nil {
|
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
|
}
|
|
|
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
|
|
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
|
routingStart := time.Now()
|
|
|
|
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
|
if !acquired {
|
|
return
|
|
}
|
|
if userReleaseFunc != nil {
|
|
defer userReleaseFunc()
|
|
}
|
|
|
|
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
|
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
|
status, code, message := billingErrorDetails(err)
|
|
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
|
return
|
|
}
|
|
|
|
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
|
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
|
|
|
maxAccountSwitches := h.maxAccountSwitches
|
|
switchCount := 0
|
|
failedAccountIDs := make(map[int64]struct{})
|
|
sameAccountRetryCount := make(map[int64]int)
|
|
var lastFailoverErr *service.UpstreamFailoverError
|
|
|
|
for {
|
|
c.Set("openai_chat_completions_fallback_model", "")
|
|
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
|
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
|
c.Request.Context(),
|
|
apiKey.GroupID,
|
|
"",
|
|
sessionHash,
|
|
reqModel,
|
|
failedAccountIDs,
|
|
service.OpenAIUpstreamTransportAny,
|
|
)
|
|
if err != nil {
|
|
reqLog.Warn("openai_chat_completions.account_select_failed",
|
|
zap.Error(err),
|
|
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
|
)
|
|
if len(failedAccountIDs) == 0 {
|
|
defaultModel := ""
|
|
if apiKey.Group != nil {
|
|
defaultModel = apiKey.Group.DefaultMappedModel
|
|
}
|
|
if defaultModel != "" && defaultModel != reqModel {
|
|
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
|
zap.String("default_mapped_model", defaultModel),
|
|
)
|
|
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
|
c.Request.Context(),
|
|
apiKey.GroupID,
|
|
"",
|
|
sessionHash,
|
|
defaultModel,
|
|
failedAccountIDs,
|
|
service.OpenAIUpstreamTransportAny,
|
|
)
|
|
if err == nil && selection != nil {
|
|
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
|
}
|
|
}
|
|
if err != nil {
|
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
|
return
|
|
}
|
|
} else {
|
|
if lastFailoverErr != nil {
|
|
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
|
} else {
|
|
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
if selection == nil || selection.Account == nil {
|
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
|
return
|
|
}
|
|
account := selection.Account
|
|
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
|
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
|
_ = scheduleDecision
|
|
setOpsSelectedAccount(c, account.ID, account.Platform)
|
|
|
|
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
|
if !acquired {
|
|
return
|
|
}
|
|
|
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
|
forwardStart := time.Now()
|
|
|
|
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
|
|
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
|
if accountReleaseFunc != nil {
|
|
accountReleaseFunc()
|
|
}
|
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
|
responseLatencyMs := forwardDurationMs
|
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
|
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
|
}
|
|
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
|
if err == nil && result != nil && result.FirstTokenMs != nil {
|
|
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
|
}
|
|
if err != nil {
|
|
var failoverErr *service.UpstreamFailoverError
|
|
if errors.As(err, &failoverErr) {
|
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
|
// Pool mode: retry on the same account
|
|
if failoverErr.RetryableOnSameAccount {
|
|
retryLimit := account.GetPoolModeRetryCount()
|
|
if sameAccountRetryCount[account.ID] < retryLimit {
|
|
sameAccountRetryCount[account.ID]++
|
|
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
zap.Int("retry_limit", retryLimit),
|
|
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
|
)
|
|
select {
|
|
case <-c.Request.Context().Done():
|
|
return
|
|
case <-time.After(sameAccountRetryDelay):
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
h.gatewayService.RecordOpenAIAccountSwitch()
|
|
failedAccountIDs[account.ID] = struct{}{}
|
|
lastFailoverErr = failoverErr
|
|
if switchCount >= maxAccountSwitches {
|
|
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
|
return
|
|
}
|
|
switchCount++
|
|
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.Int("upstream_status", failoverErr.StatusCode),
|
|
zap.Int("switch_count", switchCount),
|
|
zap.Int("max_switches", maxAccountSwitches),
|
|
)
|
|
continue
|
|
}
|
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
|
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
|
reqLog.Warn("openai_chat_completions.forward_failed",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.Bool("fallback_error_response_written", wroteFallback),
|
|
zap.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if result != nil {
|
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
|
} else {
|
|
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
|
}
|
|
|
|
userAgent := c.GetHeader("User-Agent")
|
|
clientIP := ip.GetClientIP(c)
|
|
|
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
|
Result: result,
|
|
APIKey: apiKey,
|
|
User: apiKey.User,
|
|
Account: account,
|
|
Subscription: subscription,
|
|
UserAgent: userAgent,
|
|
IPAddress: clientIP,
|
|
APIKeyService: h.apiKeyService,
|
|
}); err != nil {
|
|
logger.L().With(
|
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
|
zap.Int64("user_id", subject.UserID),
|
|
zap.Int64("api_key_id", apiKey.ID),
|
|
zap.Any("group_id", apiKey.GroupID),
|
|
zap.String("model", reqModel),
|
|
zap.Int64("account_id", account.ID),
|
|
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
|
}
|
|
})
|
|
reqLog.Debug("openai_chat_completions.request_completed",
|
|
zap.Int64("account_id", account.ID),
|
|
zap.Int("switch_count", switchCount),
|
|
)
|
|
return
|
|
}
|
|
}
|