diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index d7bb50fc..8a6621a1 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -84,10 +84,12 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-haiku-4-5": "claude-sonnet-4-5", "claude-haiku-4-5-20251001": "claude-sonnet-4-5", // Gemini 2.5 白名单 - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", // Gemini 3 白名单 "gemini-3-flash": "gemini-3-flash", "gemini-3-pro-high": "gemini-3-pro-high", diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go index 29605ac6..de66137f 100644 --- a/backend/internal/domain/constants_test.go +++ b/backend/internal/domain/constants_test.go @@ -6,6 +6,8 @@ func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) t.Parallel() cases := map[string]string{ + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", "gemini-3.1-flash-image": "gemini-3.1-flash-image", "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", "gemini-3-pro-image": "gemini-3.1-flash-image", diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index fad8a33c..7c4d4638 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -628,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) { // TestAccountRequest represents the request 
body for testing an account type TestAccountRequest struct { ModelID string `json:"model_id"` + Prompt string `json:"prompt"` } type SyncFromCRSRequest struct { @@ -658,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) { _ = c.ShouldBindJSON(&req) // Use AccountTestService to test the account with SSE streaming - if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil { + if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil { // Error already sent via SSE, just log return } diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index b0da6c5e..aa82b24f 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -249,11 +249,12 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) + trend, hit, err := h.getUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -321,11 +322,12 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return } + c.Header("X-Snapshot-Cache", 
cacheStatusValue(hit)) response.Success(c, gin.H{ "models": stats, @@ -391,11 +393,12 @@ func (h *DashboardHandler) GetGroupStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getGroupStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get group statistics") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "groups": stats, @@ -416,11 +419,12 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) { limit = 5 } - trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) + trend, hit, err := h.getAPIKeyUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit) if err != nil { response.Error(c, 500, "Failed to get API key usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, @@ -442,11 +446,12 @@ func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) { limit = 12 } - trend, err := h.dashboardService.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit) + trend, hit, err := h.getUserUsageTrendCached(c.Request.Context(), startTime, endTime, granularity, limit) if err != nil { response.Error(c, 500, "Failed to get user usage trend") return } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) response.Success(c, gin.H{ "trend": trend, diff --git a/backend/internal/handler/admin/dashboard_handler_cache_test.go b/backend/internal/handler/admin/dashboard_handler_cache_test.go new file mode 100644 index 00000000..ec888849 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_cache_test.go @@ -0,0 +1,118 @@ +package admin + +import ( + "context" + 
"net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dashboardUsageRepoCacheProbe struct { + service.UsageLogRepository + trendCalls atomic.Int32 + usersTrendCalls atomic.Int32 +} + +func (r *dashboardUsageRepoCacheProbe) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + r.trendCalls.Add(1) + return []usagestats.TrendDataPoint{{ + Date: "2026-03-11", + Requests: 1, + TotalTokens: 2, + Cost: 3, + ActualCost: 4, + }}, nil +} + +func (r *dashboardUsageRepoCacheProbe) GetUserUsageTrend( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + limit int, +) ([]usagestats.UserUsageTrendPoint, error) { + r.usersTrendCalls.Add(1) + return []usagestats.UserUsageTrendPoint{{ + Date: "2026-03-11", + UserID: 1, + Email: "cache@test.dev", + Requests: 2, + Tokens: 20, + Cost: 2, + ActualCost: 1, + }}, nil +} + +func resetDashboardReadCachesForTest() { + dashboardTrendCache = newSnapshotCache(30 * time.Second) + dashboardUsersTrendCache = newSnapshotCache(30 * time.Second) + dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second) + dashboardModelStatsCache = newSnapshotCache(30 * time.Second) + dashboardGroupStatsCache = newSnapshotCache(30 * time.Second) + dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) +} + +func TestDashboardHandler_GetUsageTrend_UsesCache(t *testing.T) { + t.Cleanup(resetDashboardReadCachesForTest) + resetDashboardReadCachesForTest() + + gin.SetMode(gin.TestMode) + repo := &dashboardUsageRepoCacheProbe{} + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := 
NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + + req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil) + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) + require.Equal(t, int32(1), repo.trendCalls.Load()) +} + +func TestDashboardHandler_GetUserUsageTrend_UsesCache(t *testing.T) { + t.Cleanup(resetDashboardReadCachesForTest) + resetDashboardReadCachesForTest() + + gin.SetMode(gin.TestMode) + repo := &dashboardUsageRepoCacheProbe{} + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/users-trend", handler.GetUserUsageTrend) + + req1 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil) + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code) + require.Equal(t, "miss", rec1.Header().Get("X-Snapshot-Cache")) + + req2 := httptest.NewRequest(http.MethodGet, "/admin/dashboard/users-trend?start_date=2026-03-01&end_date=2026-03-07&granularity=day&limit=8", nil) + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + require.Equal(t, "hit", rec2.Header().Get("X-Snapshot-Cache")) + require.Equal(t, int32(1), repo.usersTrendCalls.Load()) +} diff --git 
a/backend/internal/handler/admin/dashboard_query_cache.go b/backend/internal/handler/admin/dashboard_query_cache.go new file mode 100644 index 00000000..47af5117 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_query_cache.go @@ -0,0 +1,200 @@ +package admin + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" +) + +var ( + dashboardTrendCache = newSnapshotCache(30 * time.Second) + dashboardModelStatsCache = newSnapshotCache(30 * time.Second) + dashboardGroupStatsCache = newSnapshotCache(30 * time.Second) + dashboardUsersTrendCache = newSnapshotCache(30 * time.Second) + dashboardAPIKeysTrendCache = newSnapshotCache(30 * time.Second) +) + +type dashboardTrendCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Model string `json:"model"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` +} + +type dashboardModelGroupCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` +} + +type dashboardEntityTrendCacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + Limit int `json:"limit"` +} + +func cacheStatusValue(hit bool) string { + if hit { + return "hit" + } + return "miss" +} + +func mustMarshalDashboardCacheKey(value any) string { + raw, err := json.Marshal(value) + if err != nil { + return "" + } + return string(raw) +} + +func 
snapshotPayloadAs[T any](payload any) (T, error) { + typed, ok := payload.(T) + if !ok { + var zero T + return zero, fmt.Errorf("unexpected cache payload type %T", payload) + } + return typed, nil +} + +func (h *DashboardHandler) getUsageTrendCached( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + Model: model, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.TrendDataPoint](entry.Payload) + return trend, hit, err +} + +func (h *DashboardHandler) getModelStatsCached( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetModelStatsWithFilters(ctx, 
startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + stats, err := snapshotPayloadAs[[]usagestats.ModelStat](entry.Payload) + return stats, hit, err +} + +func (h *DashboardHandler) getGroupStatsCached( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.GroupStat, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardModelGroupCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + GroupID: groupID, + RequestType: requestType, + Stream: stream, + BillingType: billingType, + }) + entry, hit, err := dashboardGroupStatsCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + }) + if err != nil { + return nil, hit, err + } + stats, err := snapshotPayloadAs[[]usagestats.GroupStat](entry.Payload) + return stats, hit, err +} + +func (h *DashboardHandler) getAPIKeyUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + Limit: limit, + }) + entry, hit, err := dashboardAPIKeysTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.APIKeyUsageTrendPoint](entry.Payload) + return trend, hit, err +} + +func (h *DashboardHandler) 
getUserUsageTrendCached(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, bool, error) { + key := mustMarshalDashboardCacheKey(dashboardEntityTrendCacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + Limit: limit, + }) + entry, hit, err := dashboardUsersTrendCache.GetOrLoad(key, func() (any, error) { + return h.dashboardService.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit) + }) + if err != nil { + return nil, hit, err + } + trend, err := snapshotPayloadAs[[]usagestats.UserUsageTrendPoint](entry.Payload) + return trend, hit, err +} diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go index f6db69f3..16e10339 100644 --- a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -1,7 +1,9 @@ package admin import ( + "context" "encoding/json" + "errors" "net/http" "strconv" "strings" @@ -111,20 +113,45 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { }) cacheKey := string(keyRaw) - if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok { - if cached.ETag != "" { - c.Header("ETag", cached.ETag) - c.Header("Vary", "If-None-Match") - if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { - c.Status(http.StatusNotModified) - return - } - } - c.Header("X-Snapshot-Cache", "hit") - response.Success(c, cached.Payload) + cached, hit, err := dashboardSnapshotV2Cache.GetOrLoad(cacheKey, func() (any, error) { + return h.buildSnapshotV2Response( + c.Request.Context(), + startTime, + endTime, + granularity, + filters, + includeStats, + includeTrend, + includeModels, + includeGroups, + includeUsersTrend, + usersTrendLimit, + ) + }) + if err != nil { + response.Error(c, 500, err.Error()) return } + if cached.ETag != "" { + 
c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", cacheStatusValue(hit)) + response.Success(c, cached.Payload) +} +func (h *DashboardHandler) buildSnapshotV2Response( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + filters *dashboardSnapshotV2Filters, + includeStats, includeTrend, includeModels, includeGroups, includeUsersTrend bool, + usersTrendLimit int, +) (*dashboardSnapshotV2Response, error) { resp := &dashboardSnapshotV2Response{ GeneratedAt: time.Now().UTC().Format(time.RFC3339), StartDate: startTime.Format("2006-01-02"), @@ -133,10 +160,9 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { } if includeStats { - stats, err := h.dashboardService.GetDashboardStats(c.Request.Context()) + stats, err := h.dashboardService.GetDashboardStats(ctx) if err != nil { - response.Error(c, 500, "Failed to get dashboard statistics") - return + return nil, errors.New("failed to get dashboard statistics") } resp.Stats = &dashboardSnapshotV2Stats{ DashboardStats: *stats, @@ -145,8 +171,8 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { } if includeTrend { - trend, err := h.dashboardService.GetUsageTrendWithFilters( - c.Request.Context(), + trend, _, err := h.getUsageTrendCached( + ctx, startTime, endTime, granularity, @@ -160,15 +186,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { filters.BillingType, ) if err != nil { - response.Error(c, 500, "Failed to get usage trend") - return + return nil, errors.New("failed to get usage trend") } resp.Trend = trend } if includeModels { - models, err := h.dashboardService.GetModelStatsWithFilters( - c.Request.Context(), + models, _, err := h.getModelStatsCached( + ctx, startTime, endTime, filters.UserID, @@ -180,15 +205,14 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { filters.BillingType, 
) if err != nil { - response.Error(c, 500, "Failed to get model statistics") - return + return nil, errors.New("failed to get model statistics") } resp.Models = models } if includeGroups { - groups, err := h.dashboardService.GetGroupStatsWithFilters( - c.Request.Context(), + groups, _, err := h.getGroupStatsCached( + ctx, startTime, endTime, filters.UserID, @@ -200,34 +224,20 @@ func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { filters.BillingType, ) if err != nil { - response.Error(c, 500, "Failed to get group statistics") - return + return nil, errors.New("failed to get group statistics") } resp.Groups = groups } if includeUsersTrend { - usersTrend, err := h.dashboardService.GetUserUsageTrend( - c.Request.Context(), - startTime, - endTime, - granularity, - usersTrendLimit, - ) + usersTrend, _, err := h.getUserUsageTrendCached(ctx, startTime, endTime, granularity, usersTrendLimit) if err != nil { - response.Error(c, 500, "Failed to get user usage trend") - return + return nil, errors.New("failed to get user usage trend") } resp.UsersTrend = usersTrend } - cached := dashboardSnapshotV2Cache.Set(cacheKey, resp) - if cached.ETag != "" { - c.Header("ETag", cached.ETag) - c.Header("Vary", "If-None-Match") - } - c.Header("X-Snapshot-Cache", "miss") - response.Success(c, resp) + return resp, nil } func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) { diff --git a/backend/internal/handler/admin/ops_alerts_handler.go b/backend/internal/handler/admin/ops_alerts_handler.go index c9da19c7..edc8c7f7 100644 --- a/backend/internal/handler/admin/ops_alerts_handler.go +++ b/backend/internal/handler/admin/ops_alerts_handler.go @@ -23,6 +23,13 @@ var validOpsAlertMetricTypes = []string{ "cpu_usage_percent", "memory_usage_percent", "concurrency_queue_depth", + "group_available_accounts", + "group_available_ratio", + "group_rate_limit_ratio", + "account_rate_limited_count", + "account_error_count", + "account_error_ratio", + 
"overload_account_count", } var validOpsAlertMetricTypeSet = func() map[string]struct{} { @@ -82,7 +89,10 @@ func isPercentOrRateMetric(metricType string) bool { "error_rate", "upstream_error_rate", "cpu_usage_percent", - "memory_usage_percent": + "memory_usage_percent", + "group_available_ratio", + "group_rate_limit_ratio", + "account_error_ratio": return true default: return false diff --git a/backend/internal/handler/admin/snapshot_cache.go b/backend/internal/handler/admin/snapshot_cache.go index 809760a7..d6973ff9 100644 --- a/backend/internal/handler/admin/snapshot_cache.go +++ b/backend/internal/handler/admin/snapshot_cache.go @@ -7,6 +7,8 @@ import ( "strings" "sync" "time" + + "golang.org/x/sync/singleflight" ) type snapshotCacheEntry struct { @@ -19,6 +21,12 @@ type snapshotCache struct { mu sync.RWMutex ttl time.Duration items map[string]snapshotCacheEntry + sf singleflight.Group +} + +type snapshotCacheLoadResult struct { + Entry snapshotCacheEntry + Hit bool } func newSnapshotCache(ttl time.Duration) *snapshotCache { @@ -70,6 +78,41 @@ func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry { return entry } +func (c *snapshotCache) GetOrLoad(key string, load func() (any, error)) (snapshotCacheEntry, bool, error) { + if load == nil { + return snapshotCacheEntry{}, false, nil + } + if entry, ok := c.Get(key); ok { + return entry, true, nil + } + if c == nil || key == "" { + payload, err := load() + if err != nil { + return snapshotCacheEntry{}, false, err + } + return c.Set(key, payload), false, nil + } + + value, err, _ := c.sf.Do(key, func() (any, error) { + if entry, ok := c.Get(key); ok { + return snapshotCacheLoadResult{Entry: entry, Hit: true}, nil + } + payload, err := load() + if err != nil { + return nil, err + } + return snapshotCacheLoadResult{Entry: c.Set(key, payload), Hit: false}, nil + }) + if err != nil { + return snapshotCacheEntry{}, false, err + } + result, ok := value.(snapshotCacheLoadResult) + if !ok { + return 
snapshotCacheEntry{}, false, nil + } + return result.Entry, result.Hit, nil +} + func buildETagFromAny(payload any) string { raw, err := json.Marshal(payload) if err != nil { diff --git a/backend/internal/handler/admin/snapshot_cache_test.go b/backend/internal/handler/admin/snapshot_cache_test.go index f1c1453e..ee3f72ca 100644 --- a/backend/internal/handler/admin/snapshot_cache_test.go +++ b/backend/internal/handler/admin/snapshot_cache_test.go @@ -3,6 +3,8 @@ package admin import ( + "sync" + "sync/atomic" "testing" "time" @@ -95,6 +97,61 @@ func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) { require.Empty(t, etag) } +func TestSnapshotCache_GetOrLoad_MissThenHit(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + var loads atomic.Int32 + + entry, hit, err := c.GetOrLoad("key1", func() (any, error) { + loads.Add(1) + return map[string]string{"hello": "world"}, nil + }) + require.NoError(t, err) + require.False(t, hit) + require.NotEmpty(t, entry.ETag) + require.Equal(t, int32(1), loads.Load()) + + entry2, hit, err := c.GetOrLoad("key1", func() (any, error) { + loads.Add(1) + return map[string]string{"unexpected": "value"}, nil + }) + require.NoError(t, err) + require.True(t, hit) + require.Equal(t, entry.ETag, entry2.ETag) + require.Equal(t, int32(1), loads.Load()) +} + +func TestSnapshotCache_GetOrLoad_ConcurrentSingleflight(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + var loads atomic.Int32 + start := make(chan struct{}) + const callers = 8 + errCh := make(chan error, callers) + + var wg sync.WaitGroup + wg.Add(callers) + for range callers { + go func() { + defer wg.Done() + <-start + _, _, err := c.GetOrLoad("shared", func() (any, error) { + loads.Add(1) + time.Sleep(20 * time.Millisecond) + return "value", nil + }) + errCh <- err + }() + } + close(start) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + require.Equal(t, int32(1), loads.Load()) +} + func TestParseBoolQueryWithDefault(t 
*testing.T) { tests := []struct { name string diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index e5b6db13..d6073551 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -216,6 +216,37 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { }) } +// ResetSubscriptionQuotaRequest represents the reset quota request +type ResetSubscriptionQuotaRequest struct { + Daily bool `json:"daily"` + Weekly bool `json:"weekly"` +} + +// ResetQuota resets daily and/or weekly usage for a subscription. +// POST /api/v1/admin/subscriptions/:id/reset-quota +func (h *SubscriptionHandler) ResetQuota(c *gin.Context) { + subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid subscription ID") + return + } + var req ResetSubscriptionQuotaRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Daily && !req.Weekly { + response.BadRequest(c, "At least one of 'daily' or 'weekly' must be true") + return + } + sub, err := h.subscriptionService.AdminResetQuota(c.Request.Context(), subscriptionID, req.Daily, req.Weekly) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, dto.UserSubscriptionFromServiceAdmin(sub)) +} + // Revoke handles revoking a subscription // DELETE /api/v1/admin/subscriptions/:id func (h *SubscriptionHandler) Revoke(c *gin.Context) { diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go new file mode 100644 index 00000000..6900e7cd --- /dev/null +++ b/backend/internal/handler/openai_chat_completions.go @@ -0,0 +1,290 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// ChatCompletions handles OpenAI Chat Completions API requests. +// POST /v1/chat/completions +func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + + requestStart := time.Now() + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.openai_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, 
http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + reqStream := gjson.GetBytes(body, "stream").Bool() + + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + + setOpsRequestContext(c, reqModel, reqStream, body) + + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() + + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := h.gatewayService.GenerateSessionHash(c, body) + promptCacheKey := h.gatewayService.ExtractSessionID(c, body) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) + var lastFailoverErr *service.UpstreamFailoverError + + for { + c.Set("openai_chat_completions_fallback_model", "") + reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err != nil { + 
reqLog.Warn("openai_chat_completions.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + reqLog.Info("openai_chat_completions.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_chat_completions_fallback_model", defaultModel) + } + } + if err != nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) + return + } + } else { + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted) + } + return + } + } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + account := selection.Account + sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account) + reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + _ = scheduleDecision + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return + } + + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() 
+ + defaultMappedModel := "" + if apiKey.Group != nil { + defaultMappedModel = apiKey.Group.DefaultMappedModel + } + if fallbackModel := c.GetString("openai_chat_completions_fallback_model"); fallbackModel != "" { + defaultMappedModel = fallbackModel + } + result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) + + forwardDurationMs := time.Since(forwardStart).Milliseconds() + if accountReleaseFunc != nil { + accountReleaseFunc() + } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // Pool mode: retry on the same account + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, 
failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_chat_completions.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Warn("openai_chat_completions.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return + } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai_chat_completions.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("openai_chat_completions.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) + return + } +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 2f53d655..cb2fad5d 100644 --- a/backend/internal/handler/ops_error_logger.go +++ 
b/backend/internal/handler/ops_error_logger.go @@ -31,6 +31,7 @@ const ( const ( opsErrorLogTimeout = 5 * time.Second opsErrorLogDrainTimeout = 10 * time.Second + opsErrorLogBatchWindow = 200 * time.Millisecond opsErrorLogMinWorkerCount = 4 opsErrorLogMaxWorkerCount = 32 @@ -38,6 +39,7 @@ const ( opsErrorLogQueueSizePerWorker = 128 opsErrorLogMinQueueSize = 256 opsErrorLogMaxQueueSize = 8192 + opsErrorLogBatchSize = 32 ) type opsErrorLogJob struct { @@ -82,27 +84,82 @@ func startOpsErrorLogWorkers() { for i := 0; i < workerCount; i++ { go func() { defer opsErrorLogWorkersWg.Done() - for job := range opsErrorLogQueue { - opsErrorLogQueueLen.Add(-1) - if job.ops == nil || job.entry == nil { - continue + for { + job, ok := <-opsErrorLogQueue + if !ok { + return } - func() { - defer func() { - if r := recover(); r != nil { - log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + opsErrorLogQueueLen.Add(-1) + batch := make([]opsErrorLogJob, 0, opsErrorLogBatchSize) + batch = append(batch, job) + + timer := time.NewTimer(opsErrorLogBatchWindow) + batchLoop: + for len(batch) < opsErrorLogBatchSize { + select { + case nextJob, ok := <-opsErrorLogQueue: + if !ok { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + flushOpsErrorLogBatch(batch) + return } - }() - ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, nil) - cancel() - opsErrorLogProcessed.Add(1) - }() + opsErrorLogQueueLen.Add(-1) + batch = append(batch, nextJob) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + flushOpsErrorLogBatch(batch) } }() } } +func flushOpsErrorLogBatch(batch []opsErrorLogJob) { + if len(batch) == 0 { + return + } + defer func() { + if r := recover(); r != nil { + log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack()) + } + }() + + grouped := make(map[*service.OpsService][]*service.OpsInsertErrorLogInput, 
len(batch)) + var processed int64 + for _, job := range batch { + if job.ops == nil || job.entry == nil { + continue + } + grouped[job.ops] = append(grouped[job.ops], job.entry) + processed++ + } + if processed == 0 { + return + } + + for opsSvc, entries := range grouped { + if opsSvc == nil || len(entries) == 0 { + continue + } + ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) + _ = opsSvc.RecordErrorBatch(ctx, entries) + cancel() + } + opsErrorLogProcessed.Add(processed) +} + func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 7cc68060..8ea87f18 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -159,6 +159,8 @@ var claudeModels = []modelDef{ // Antigravity 支持的 Gemini 模型 var geminiModels = []modelDef{ {ID: "gemini-2.5-flash", DisplayName: "Gemini 2.5 Flash", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-image", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: "2025-01-01T00:00:00Z"}, + {ID: "gemini-2.5-flash-image-preview", DisplayName: "Gemini 2.5 Flash Image Preview", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-2.5-flash-lite", DisplayName: "Gemini 2.5 Flash Lite", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-2.5-flash-thinking", DisplayName: "Gemini 2.5 Flash Thinking", CreatedAt: "2025-01-01T00:00:00Z"}, {ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"}, diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go index f7cb0a24..9fc09b1b 100644 --- a/backend/internal/pkg/antigravity/claude_types_test.go +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -13,6 +13,8 @@ func TestDefaultModels_ContainsNewAndLegacyImageModels(t 
*testing.T) { requiredIDs := []string{ "claude-opus-4-6-thinking", + "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview", "gemini-3.1-flash-image", "gemini-3.1-flash-image-preview", "gemini-3-pro-image", // legacy compatibility diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go new file mode 100644 index 00000000..71b7a6f5 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -0,0 +1,733 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ChatCompletionsToResponses tests +// --------------------------------------------------------------------------- + +func TestChatCompletionsToResponses_BasicText(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hello"`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "gpt-4o", resp.Model) + assert.True(t, resp.Stream) // always forced true + assert.False(t, *resp.Store) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, "user", items[0].Role) +} + +func TestChatCompletionsToResponses_SystemMessage(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "system", Content: json.RawMessage(`"You are helpful."`)}, + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + assert.Equal(t, "system", items[0].Role) + assert.Equal(t, "user", 
items[1].Role) +} + +func TestChatCompletionsToResponses_ToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Call the function"`)}, + { + Role: "assistant", + ToolCalls: []ChatToolCall{ + { + ID: "call_1", + Type: "function", + Function: ChatFunctionCall{ + Name: "ping", + Arguments: `{"host":"example.com"}`, + }, + }, + }, + }, + { + Role: "tool", + ToolCallID: "call_1", + Content: json.RawMessage(`"pong"`), + }, + }, + Tools: []ChatTool{ + { + Type: "function", + Function: &ChatFunction{ + Name: "ping", + Description: "Ping a host", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + function_call + function_call_output = 3 + // (assistant message with empty content + tool_calls → only function_call items emitted) + require.Len(t, items, 3) + + // Check function_call item + assert.Equal(t, "function_call", items[1].Type) + assert.Equal(t, "call_1", items[1].CallID) + assert.Equal(t, "ping", items[1].Name) + + // Check function_call_output item + assert.Equal(t, "function_call_output", items[2].Type) + assert.Equal(t, "call_1", items[2].CallID) + assert.Equal(t, "pong", items[2].Output) + + // Check tools + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "ping", resp.Tools[0].Name) +} + +func TestChatCompletionsToResponses_MaxTokens(t *testing.T) { + t.Run("max_tokens", func(t *testing.T) { + maxTokens := 100 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: &maxTokens, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + // Below minMaxOutputTokens (128), should be 
clamped + assert.Equal(t, minMaxOutputTokens, *resp.MaxOutputTokens) + }) + + t.Run("max_completion_tokens_preferred", func(t *testing.T) { + maxTokens := 100 + maxCompletion := 500 + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + MaxTokens: &maxTokens, + MaxCompletionTokens: &maxCompletion, + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.MaxOutputTokens) + assert.Equal(t, 500, *resp.MaxOutputTokens) + }) +} + +func TestChatCompletionsToResponses_ReasoningEffort(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ReasoningEffort: "high", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.NotNil(t, resp.Reasoning) + assert.Equal(t, "high", resp.Reasoning.Effort) + assert.Equal(t, "auto", resp.Reasoning.Summary) +} + +func TestChatCompletionsToResponses_ImageURL(t *testing.T) { + content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]` + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(content)}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 2) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Describe this", parts[0].Text) + assert.Equal(t, "input_image", parts[1].Type) + assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) +} + +func TestChatCompletionsToResponses_LegacyFunctions(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: 
[]ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + }, + Functions: []ChatFunction{ + { + Name: "get_weather", + Description: "Get weather", + Parameters: json.RawMessage(`{"type":"object"}`), + }, + }, + FunctionCall: json.RawMessage(`{"name":"get_weather"}`), + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + require.Len(t, resp.Tools, 1) + assert.Equal(t, "function", resp.Tools[0].Type) + assert.Equal(t, "get_weather", resp.Tools[0].Name) + + // tool_choice should be converted + require.NotNil(t, resp.ToolChoice) + var tc map[string]any + require.NoError(t, json.Unmarshal(resp.ToolChoice, &tc)) + assert.Equal(t, "function", tc["type"]) +} + +func TestChatCompletionsToResponses_ServiceTier(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + ServiceTier: "flex", + Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}}, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.Equal(t, "flex", resp.ServiceTier) +} + +func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Do something"`)}, + { + Role: "assistant", + Content: json.RawMessage(`"Let me call a function."`), + ToolCalls: []ChatToolCall{ + { + ID: "call_abc", + Type: "function", + Function: ChatFunctionCall{ + Name: "do_thing", + Arguments: `{}`, + }, + }, + }, + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + // user + assistant message (with text) + function_call + require.Len(t, items, 3) + assert.Equal(t, "user", items[0].Role) + assert.Equal(t, "assistant", items[1].Role) + assert.Equal(t, "function_call", items[2].Type) +} + +// --------------------------------------------------------------------------- +// 
ResponsesToChatCompletions tests +// --------------------------------------------------------------------------- + +func TestResponsesToChatCompletions_BasicText(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_123", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "Hello, world!"}, + }, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + assert.Equal(t, "chat.completion", chat.Object) + assert.Equal(t, "gpt-4o", chat.Model) + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "Hello, world!", content) + + require.NotNil(t, chat.Usage) + assert.Equal(t, 10, chat.Usage.PromptTokens) + assert.Equal(t, 5, chat.Usage.CompletionTokens) + assert.Equal(t, 15, chat.Usage.TotalTokens) +} + +func TestResponsesToChatCompletions_ToolCalls(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_456", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "function_call", + CallID: "call_xyz", + Name: "get_weather", + Arguments: `{"city":"NYC"}`, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "tool_calls", chat.Choices[0].FinishReason) + + msg := chat.Choices[0].Message + require.Len(t, msg.ToolCalls, 1) + assert.Equal(t, "call_xyz", msg.ToolCalls[0].ID) + assert.Equal(t, "function", msg.ToolCalls[0].Type) + assert.Equal(t, "get_weather", msg.ToolCalls[0].Function.Name) + assert.Equal(t, `{"city":"NYC"}`, msg.ToolCalls[0].Function.Arguments) +} + +func TestResponsesToChatCompletions_Reasoning(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_789", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "reasoning", + 
Summary: []ResponsesSummary{ + {Type: "summary_text", Text: "I thought about it."}, + }, + }, + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "The answer is 42."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + // Reasoning summary is prepended to text + assert.Equal(t, "I thought about it.The answer is 42.", content) +} + +func TestResponsesToChatCompletions_Incomplete(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_inc", + Status: "incomplete", + IncompleteDetails: &ResponsesIncompleteDetails{Reason: "max_output_tokens"}, + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{ + {Type: "output_text", Text: "partial..."}, + }, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "length", chat.Choices[0].FinishReason) +} + +func TestResponsesToChatCompletions_CachedTokens(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_cache", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "message", + Content: []ResponsesContentPart{{Type: "output_text", Text: "cached"}}, + }, + }, + Usage: &ResponsesUsage{ + InputTokens: 100, + OutputTokens: 10, + TotalTokens: 110, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 80, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.NotNil(t, chat.Usage) + require.NotNil(t, chat.Usage.PromptTokensDetails) + assert.Equal(t, 80, chat.Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesToChatCompletions_WebSearch(t *testing.T) { + resp := &ResponsesResponse{ + ID: "resp_ws", + Status: "completed", + Output: []ResponsesOutput{ + { + Type: "web_search_call", + Action: &WebSearchAction{Type: "search", Query: "test"}, + }, + { + Type: "message", + Content: 
[]ResponsesContentPart{{Type: "output_text", Text: "search results"}}, + }, + }, + } + + chat := ResponsesToChatCompletions(resp, "gpt-4o") + require.Len(t, chat.Choices, 1) + assert.Equal(t, "stop", chat.Choices[0].FinishReason) + + var content string + require.NoError(t, json.Unmarshal(chat.Choices[0].Message.Content, &content)) + assert.Equal(t, "search results", content) +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesEventToChatChunks tests +// --------------------------------------------------------------------------- + +func TestResponsesEventToChatChunks_TextDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + + // response.created → role chunk + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ + ID: "resp_stream", + }, + }, state) + require.Len(t, chunks, 1) + assert.Equal(t, "assistant", chunks[0].Choices[0].Delta.Role) + assert.True(t, state.SentRole) + + // response.output_text.delta → content chunk + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: "Hello", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Hello", *chunks[0].Choices[0].Delta.Content) +} + +func TestResponsesEventToChatChunks_ToolCallDelta(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + // response.output_item.added (function_call) — output_index=1 (e.g. 
after a message item at 0) + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 1, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_1", + Name: "get_weather", + }, + }, state) + require.Len(t, chunks, 1) + require.Len(t, chunks[0].Choices[0].Delta.ToolCalls, 1) + tc := chunks[0].Choices[0].Delta.ToolCalls[0] + assert.Equal(t, "call_1", tc.ID) + assert.Equal(t, "get_weather", tc.Function.Name) + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index) + + // response.function_call_arguments.delta — uses output_index (NOT call_id) to find tool + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, // matches the output_index from output_item.added above + Delta: `{"city":`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "argument delta must use same index as the tool call") + assert.Equal(t, `{"city":`, tc.Function.Arguments) + + // Add a second function call at output_index=2 + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_item.added", + OutputIndex: 2, + Item: &ResponsesOutput{ + Type: "function_call", + CallID: "call_2", + Name: "get_time", + }, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool call should get index 1") + + // Argument delta for second tool call + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 2, + Delta: `{"tz":"UTC"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 1, *tc.Index, "second tool arg delta must use index 1") + + // Argument delta for first tool call (interleaved) + 
chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.function_call_arguments.delta", + OutputIndex: 1, + Delta: `"Tokyo"}`, + }, state) + require.Len(t, chunks, 1) + tc = chunks[0].Choices[0].Delta.ToolCalls[0] + require.NotNil(t, tc.Index) + assert.Equal(t, 0, *tc.Index, "first tool arg delta must still use index 0") +} + +func TestResponsesEventToChatChunks_Completed(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 50, + OutputTokens: 20, + TotalTokens: 70, + InputTokensDetails: &ResponsesInputTokensDetails{ + CachedTokens: 30, + }, + }, + }, + }, state) + // finish chunk + usage chunk + require.Len(t, chunks, 2) + + // First chunk: finish_reason + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Second chunk: usage + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 50, chunks[1].Usage.PromptTokens) + assert.Equal(t, 20, chunks[1].Usage.CompletionTokens) + assert.Equal(t, 70, chunks[1].Usage.TotalTokens) + require.NotNil(t, chunks[1].Usage.PromptTokensDetails) + assert.Equal(t, 30, chunks[1].Usage.PromptTokensDetails.CachedTokens) +} + +func TestResponsesEventToChatChunks_CompletedWithToolCalls(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SawToolCall = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + }, + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "tool_calls", *chunks[0].Choices[0].FinishReason) +} + +func TestResponsesEventToChatChunks_ReasoningDelta(t *testing.T) { + state := 
NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.SentRole = true + + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.reasoning_summary_text.delta", + Delta: "Thinking...", + }, state) + require.Len(t, chunks, 1) + require.NotNil(t, chunks[0].Choices[0].Delta.Content) + assert.Equal(t, "Thinking...", *chunks[0].Choices[0].Delta.Content) +} + +func TestFinalizeResponsesChatStream(t *testing.T) { + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + state.Usage = &ChatUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + + chunks := FinalizeResponsesChatStream(state) + require.Len(t, chunks, 2) + + // Finish chunk + require.NotNil(t, chunks[0].Choices[0].FinishReason) + assert.Equal(t, "stop", *chunks[0].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, chunks[1].Usage) + assert.Equal(t, 100, chunks[1].Usage.PromptTokens) + + // Idempotent: second call returns nil + assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestFinalizeResponsesChatStream_AfterCompleted(t *testing.T) { + // If response.completed already emitted the finish chunk, FinalizeResponsesChatStream + // must be a no-op (prevents double finish_reason being sent to the client). + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + // Simulate response.completed + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }, + }, state) + require.NotEmpty(t, chunks) // finish + usage chunks + + // Now FinalizeResponsesChatStream should return nil — already finalized. 
+ assert.Nil(t, FinalizeResponsesChatStream(state)) +} + +func TestChatChunkToSSE(t *testing.T) { + chunk := ChatCompletionsChunk{ + ID: "chatcmpl-test", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "gpt-4o", + Choices: []ChatChunkChoice{ + { + Index: 0, + Delta: ChatDelta{Role: "assistant"}, + FinishReason: nil, + }, + }, + } + + sse, err := ChatChunkToSSE(chunk) + require.NoError(t, err) + assert.Contains(t, sse, "data: ") + assert.Contains(t, sse, "chatcmpl-test") + assert.Contains(t, sse, "assistant") + assert.True(t, len(sse) > 10) +} + +// --------------------------------------------------------------------------- +// Stream round-trip test +// --------------------------------------------------------------------------- + +func TestChatCompletionsStreamRoundTrip(t *testing.T) { + // Simulate: client sends chat completions request, upstream returns Responses SSE events. + // Verify that the streaming state machine produces correct chat completions chunks. + + state := NewResponsesEventToChatState() + state.Model = "gpt-4o" + state.IncludeUsage = true + + var allChunks []ChatCompletionsChunk + + // 1. response.created + chunks := ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.created", + Response: &ResponsesResponse{ID: "resp_rt"}, + }, state) + allChunks = append(allChunks, chunks...) + + // 2. text deltas + for _, text := range []string{"Hello", ", ", "world", "!"} { + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.output_text.delta", + Delta: text, + }, state) + allChunks = append(allChunks, chunks...) + } + + // 3. response.completed + chunks = ResponsesEventToChatChunks(&ResponsesStreamEvent{ + Type: "response.completed", + Response: &ResponsesResponse{ + Status: "completed", + Usage: &ResponsesUsage{ + InputTokens: 10, + OutputTokens: 4, + TotalTokens: 14, + }, + }, + }, state) + allChunks = append(allChunks, chunks...) 
+ + // Verify: role chunk + 4 text chunks + finish chunk + usage chunk = 7 + require.Len(t, allChunks, 7) + + // First chunk has role + assert.Equal(t, "assistant", allChunks[0].Choices[0].Delta.Role) + + // Text chunks + var fullText string + for i := 1; i <= 4; i++ { + require.NotNil(t, allChunks[i].Choices[0].Delta.Content) + fullText += *allChunks[i].Choices[0].Delta.Content + } + assert.Equal(t, "Hello, world!", fullText) + + // Finish chunk + require.NotNil(t, allChunks[5].Choices[0].FinishReason) + assert.Equal(t, "stop", *allChunks[5].Choices[0].FinishReason) + + // Usage chunk + require.NotNil(t, allChunks[6].Usage) + assert.Equal(t, 10, allChunks[6].Usage.PromptTokens) + assert.Equal(t, 4, allChunks[6].Usage.CompletionTokens) + + // All chunks share the same ID + for _, c := range allChunks { + assert.Equal(t, "resp_rt", c.ID) + } +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go new file mode 100644 index 00000000..37285b09 --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -0,0 +1,312 @@ +package apicompat + +import ( + "encoding/json" + "fmt" +) + +// ChatCompletionsToResponses converts a Chat Completions request into a +// Responses API request. The upstream always streams, so Stream is forced to +// true. store is always false and reasoning.encrypted_content is always +// included so that the response translator has full context. 
+func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, error) { + input, err := convertChatMessagesToResponsesInput(req.Messages) + if err != nil { + return nil, err + } + + inputJSON, err := json.Marshal(input) + if err != nil { + return nil, err + } + + out := &ResponsesRequest{ + Model: req.Model, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: true, // upstream always streams + Include: []string{"reasoning.encrypted_content"}, + ServiceTier: req.ServiceTier, + } + + storeFalse := false + out.Store = &storeFalse + + // max_tokens / max_completion_tokens → max_output_tokens, prefer max_completion_tokens + maxTokens := 0 + if req.MaxTokens != nil { + maxTokens = *req.MaxTokens + } + if req.MaxCompletionTokens != nil { + maxTokens = *req.MaxCompletionTokens + } + if maxTokens > 0 { + v := maxTokens + if v < minMaxOutputTokens { + v = minMaxOutputTokens + } + out.MaxOutputTokens = &v + } + + // reasoning_effort → reasoning.effort + reasoning.summary="auto" + if req.ReasoningEffort != "" { + out.Reasoning = &ResponsesReasoning{ + Effort: req.ReasoningEffort, + Summary: "auto", + } + } + + // tools[] and legacy functions[] → ResponsesTool[] + if len(req.Tools) > 0 || len(req.Functions) > 0 { + out.Tools = convertChatToolsToResponses(req.Tools, req.Functions) + } + + // tool_choice: already compatible format — pass through directly. + // Legacy function_call needs mapping. + if len(req.ToolChoice) > 0 { + out.ToolChoice = req.ToolChoice + } else if len(req.FunctionCall) > 0 { + tc, err := convertChatFunctionCallToToolChoice(req.FunctionCall) + if err != nil { + return nil, fmt.Errorf("convert function_call: %w", err) + } + out.ToolChoice = tc + } + + return out, nil +} + +// convertChatMessagesToResponsesInput converts the Chat Completions messages +// array into a Responses API input items array. 
+func convertChatMessagesToResponsesInput(msgs []ChatMessage) ([]ResponsesInputItem, error) { + var out []ResponsesInputItem + for _, m := range msgs { + items, err := chatMessageToResponsesItems(m) + if err != nil { + return nil, err + } + out = append(out, items...) + } + return out, nil +} + +// chatMessageToResponsesItems converts a single ChatMessage into one or more +// ResponsesInputItem values. +func chatMessageToResponsesItems(m ChatMessage) ([]ResponsesInputItem, error) { + switch m.Role { + case "system": + return chatSystemToResponses(m) + case "user": + return chatUserToResponses(m) + case "assistant": + return chatAssistantToResponses(m) + case "tool": + return chatToolToResponses(m) + case "function": + return chatFunctionToResponses(m) + default: + return chatUserToResponses(m) + } +} + +// chatSystemToResponses converts a system message. +func chatSystemToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + text, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + content, err := json.Marshal(text) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "system", Content: content}}, nil +} + +// chatUserToResponses converts a user message, handling both plain strings and +// multi-modal content arrays. +func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + // Try plain string first. 
+ var s string + if err := json.Unmarshal(m.Content, &s); err == nil { + content, _ := json.Marshal(s) + return []ResponsesInputItem{{Role: "user", Content: content}}, nil + } + + var parts []ChatContentPart + if err := json.Unmarshal(m.Content, &parts); err != nil { + return nil, fmt.Errorf("parse user content: %w", err) + } + + var responseParts []ResponsesContentPart + for _, p := range parts { + switch p.Type { + case "text": + if p.Text != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_text", + Text: p.Text, + }) + } + case "image_url": + if p.ImageURL != nil && p.ImageURL.URL != "" { + responseParts = append(responseParts, ResponsesContentPart{ + Type: "input_image", + ImageURL: p.ImageURL.URL, + }) + } + } + } + + content, err := json.Marshal(responseParts) + if err != nil { + return nil, err + } + return []ResponsesInputItem{{Role: "user", Content: content}}, nil +} + +// chatAssistantToResponses converts an assistant message. If there is both +// text content and tool_calls, the text is emitted as an assistant message +// first, then each tool_call becomes a function_call item. If the content is +// empty/nil and there are tool_calls, only function_call items are emitted. +func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + var items []ResponsesInputItem + + // Emit assistant message with output_text if content is non-empty. + if len(m.Content) > 0 { + var s string + if err := json.Unmarshal(m.Content, &s); err == nil && s != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: s}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err + } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + } + } + + // Emit one function_call item per tool_call. 
+ for _, tc := range m.ToolCalls { + args := tc.Function.Arguments + if args == "" { + args = "{}" + } + items = append(items, ResponsesInputItem{ + Type: "function_call", + CallID: tc.ID, + Name: tc.Function.Name, + Arguments: args, + ID: tc.ID, + }) + } + + return items, nil +} + +// chatToolToResponses converts a tool result message (role=tool) into a +// function_call_output item. +func chatToolToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.ToolCallID, + Output: output, + }}, nil +} + +// chatFunctionToResponses converts a legacy function result message +// (role=function) into a function_call_output item. The Name field is used as +// call_id since legacy function calls do not carry a separate call_id. +func chatFunctionToResponses(m ChatMessage) ([]ResponsesInputItem, error) { + output, err := parseChatContent(m.Content) + if err != nil { + return nil, err + } + if output == "" { + output = "(empty)" + } + return []ResponsesInputItem{{ + Type: "function_call_output", + CallID: m.Name, + Output: output, + }}, nil +} + +// parseChatContent returns the string value of a ChatMessage Content field. +// Content must be a JSON string. Returns "" if content is null or empty. +func parseChatContent(raw json.RawMessage) (string, error) { + if len(raw) == 0 { + return "", nil + } + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "", fmt.Errorf("parse content as string: %w", err) + } + return s, nil +} + +// convertChatToolsToResponses maps Chat Completions tool definitions and legacy +// function definitions to Responses API tool definitions. 
+func convertChatToolsToResponses(tools []ChatTool, functions []ChatFunction) []ResponsesTool { + var out []ResponsesTool + + for _, t := range tools { + if t.Type != "function" || t.Function == nil { + continue + } + rt := ResponsesTool{ + Type: "function", + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + Strict: t.Function.Strict, + } + out = append(out, rt) + } + + // Legacy functions[] are treated as function-type tools. + for _, f := range functions { + rt := ResponsesTool{ + Type: "function", + Name: f.Name, + Description: f.Description, + Parameters: f.Parameters, + Strict: f.Strict, + } + out = append(out, rt) + } + + return out +} + +// convertChatFunctionCallToToolChoice maps the legacy function_call field to a +// Responses API tool_choice value. +// +// "auto" → "auto" +// "none" → "none" +// {"name":"X"} → {"type":"function","function":{"name":"X"}} +func convertChatFunctionCallToToolChoice(raw json.RawMessage) (json.RawMessage, error) { + // Try string first ("auto", "none", etc.) — pass through as-is. 
+ var s string + if err := json.Unmarshal(raw, &s); err == nil { + return json.Marshal(s) + } + + // Object form: {"name":"X"} + var obj struct { + Name string `json:"name"` + } + if err := json.Unmarshal(raw, &obj); err != nil { + return nil, err + } + return json.Marshal(map[string]any{ + "type": "function", + "function": map[string]string{"name": obj.Name}, + }) +} diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go new file mode 100644 index 00000000..8f83bce4 --- /dev/null +++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go @@ -0,0 +1,368 @@ +package apicompat + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// --------------------------------------------------------------------------- +// Non-streaming: ResponsesResponse → ChatCompletionsResponse +// --------------------------------------------------------------------------- + +// ResponsesToChatCompletions converts a Responses API response into a Chat +// Completions response. Text output items are concatenated into +// choices[0].message.content; function_call items become tool_calls. 
+func ResponsesToChatCompletions(resp *ResponsesResponse, model string) *ChatCompletionsResponse { + id := resp.ID + if id == "" { + id = generateChatCmplID() + } + + out := &ChatCompletionsResponse{ + ID: id, + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + } + + var contentText string + var toolCalls []ChatToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, part := range item.Content { + if part.Type == "output_text" && part.Text != "" { + contentText += part.Text + } + } + case "function_call": + toolCalls = append(toolCalls, ChatToolCall{ + ID: item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: item.Name, + Arguments: item.Arguments, + }, + }) + case "reasoning": + for _, s := range item.Summary { + if s.Type == "summary_text" && s.Text != "" { + contentText += s.Text + } + } + case "web_search_call": + // silently consumed — results already incorporated into text output + } + } + + msg := ChatMessage{Role: "assistant"} + if len(toolCalls) > 0 { + msg.ToolCalls = toolCalls + } + if contentText != "" { + raw, _ := json.Marshal(contentText) + msg.Content = raw + } + + finishReason := responsesStatusToChatFinishReason(resp.Status, resp.IncompleteDetails, toolCalls) + + out.Choices = []ChatChoice{{ + Index: 0, + Message: msg, + FinishReason: finishReason, + }} + + if resp.Usage != nil { + usage := &ChatUsage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + if resp.Usage.InputTokensDetails != nil && resp.Usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: resp.Usage.InputTokensDetails.CachedTokens, + } + } + out.Usage = usage + } + + return out +} + +func responsesStatusToChatFinishReason(status string, details *ResponsesIncompleteDetails, toolCalls []ChatToolCall) string { + switch status { + case "incomplete": + 
if details != nil && details.Reason == "max_output_tokens" { + return "length" + } + return "stop" + case "completed": + if len(toolCalls) > 0 { + return "tool_calls" + } + return "stop" + default: + return "stop" + } +} + +// --------------------------------------------------------------------------- +// Streaming: ResponsesStreamEvent → []ChatCompletionsChunk (stateful converter) +// --------------------------------------------------------------------------- + +// ResponsesEventToChatState tracks state for converting a sequence of Responses +// SSE events into Chat Completions SSE chunks. +type ResponsesEventToChatState struct { + ID string + Model string + Created int64 + SentRole bool + SawToolCall bool + SawText bool + Finalized bool // true after finish chunk has been emitted + NextToolCallIndex int // next sequential tool_call index to assign + OutputIndexToToolIndex map[int]int // Responses output_index → Chat tool_calls index + IncludeUsage bool + Usage *ChatUsage +} + +// NewResponsesEventToChatState returns an initialised stream state. +func NewResponsesEventToChatState() *ResponsesEventToChatState { + return &ResponsesEventToChatState{ + ID: generateChatCmplID(), + Created: time.Now().Unix(), + OutputIndexToToolIndex: make(map[int]int), + } +} + +// ResponsesEventToChatChunks converts a single Responses SSE event into zero +// or more Chat Completions chunks, updating state as it goes. 
+func ResponsesEventToChatChunks(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + switch evt.Type { + case "response.created": + return resToChatHandleCreated(evt, state) + case "response.output_text.delta": + return resToChatHandleTextDelta(evt, state) + case "response.output_item.added": + return resToChatHandleOutputItemAdded(evt, state) + case "response.function_call_arguments.delta": + return resToChatHandleFuncArgsDelta(evt, state) + case "response.reasoning_summary_text.delta": + return resToChatHandleReasoningDelta(evt, state) + case "response.completed", "response.incomplete", "response.failed": + return resToChatHandleCompleted(evt, state) + default: + return nil + } +} + +// FinalizeResponsesChatStream emits a final chunk with finish_reason if the +// stream ended without a proper completion event (e.g. upstream disconnect). +// It is idempotent: if a completion event already emitted the finish chunk, +// this returns nil. +func FinalizeResponsesChatStream(state *ResponsesEventToChatState) []ChatCompletionsChunk { + if state.Finalized { + return nil + } + state.Finalized = true + + finishReason := "stop" + if state.SawToolCall { + finishReason = "tool_calls" + } + + chunks := []ChatCompletionsChunk{makeChatFinishChunk(state, finishReason)} + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +// ChatChunkToSSE formats a ChatCompletionsChunk as an SSE data line. 
+func ChatChunkToSSE(chunk ChatCompletionsChunk) (string, error) { + data, err := json.Marshal(chunk) + if err != nil { + return "", err + } + return fmt.Sprintf("data: %s\n\n", data), nil +} + +// --- internal handlers --- + +func resToChatHandleCreated(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Response != nil { + if evt.Response.ID != "" { + state.ID = evt.Response.ID + } + if state.Model == "" && evt.Response.Model != "" { + state.Model = evt.Response.Model + } + } + // Emit the role chunk. + if state.SentRole { + return nil + } + state.SentRole = true + + role := "assistant" + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Role: role})} +} + +func resToChatHandleTextDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + state.SawText = true + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Item == nil || evt.Item.Type != "function_call" { + return nil + } + + state.SawToolCall = true + idx := state.NextToolCallIndex + state.OutputIndexToToolIndex[evt.OutputIndex] = idx + state.NextToolCallIndex++ + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + ID: evt.Item.CallID, + Type: "function", + Function: ChatFunctionCall{ + Name: evt.Item.Name, + }, + }}, + })} +} + +func resToChatHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + + idx, ok := state.OutputIndexToToolIndex[evt.OutputIndex] + if !ok { + return nil + } + + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{ + ToolCalls: []ChatToolCall{{ + Index: &idx, + Function: ChatFunctionCall{ + 
Arguments: evt.Delta, + }, + }}, + })} +} + +func resToChatHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + if evt.Delta == "" { + return nil + } + content := evt.Delta + return []ChatCompletionsChunk{makeChatDeltaChunk(state, ChatDelta{Content: &content})} +} + +func resToChatHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventToChatState) []ChatCompletionsChunk { + state.Finalized = true + finishReason := "stop" + + if evt.Response != nil { + if evt.Response.Usage != nil { + u := evt.Response.Usage + usage := &ChatUsage{ + PromptTokens: u.InputTokens, + CompletionTokens: u.OutputTokens, + TotalTokens: u.InputTokens + u.OutputTokens, + } + if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails = &ChatTokenDetails{ + CachedTokens: u.InputTokensDetails.CachedTokens, + } + } + state.Usage = usage + } + + switch evt.Response.Status { + case "incomplete": + if evt.Response.IncompleteDetails != nil && evt.Response.IncompleteDetails.Reason == "max_output_tokens" { + finishReason = "length" + } + case "completed": + if state.SawToolCall { + finishReason = "tool_calls" + } + } + } else if state.SawToolCall { + finishReason = "tool_calls" + } + + var chunks []ChatCompletionsChunk + chunks = append(chunks, makeChatFinishChunk(state, finishReason)) + + if state.IncludeUsage && state.Usage != nil { + chunks = append(chunks, ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{}, + Usage: state.Usage, + }) + } + + return chunks +} + +func makeChatDeltaChunk(state *ResponsesEventToChatState, delta ChatDelta) ChatCompletionsChunk { + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: delta, + FinishReason: nil, + }}, + } +} + +func 
makeChatFinishChunk(state *ResponsesEventToChatState, finishReason string) ChatCompletionsChunk { + empty := "" + return ChatCompletionsChunk{ + ID: state.ID, + Object: "chat.completion.chunk", + Created: state.Created, + Model: state.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatDelta{Content: &empty}, + FinishReason: &finishReason, + }}, + } +} + +// generateChatCmplID returns a "chatcmpl-" prefixed random hex ID. +func generateChatCmplID() string { + b := make([]byte, 12) + _, _ = rand.Read(b) + return "chatcmpl-" + hex.EncodeToString(b) +} diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index aa58b58f..eb77d89f 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -329,6 +329,148 @@ type ResponsesStreamEvent struct { SequenceNumber int `json:"sequence_number,omitempty"` } +// --------------------------------------------------------------------------- +// OpenAI Chat Completions API types +// --------------------------------------------------------------------------- + +// ChatCompletionsRequest is the request body for POST /v1/chat/completions. 
+type ChatCompletionsRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` + Tools []ChatTool `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" + ServiceTier string `json:"service_tier,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` // string or []string + + // Legacy function calling (deprecated but still supported) + Functions []ChatFunction `json:"functions,omitempty"` + FunctionCall json.RawMessage `json:"function_call,omitempty"` +} + +// ChatStreamOptions configures streaming behavior. +type ChatStreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +// ChatMessage is a single message in the Chat Completions conversation. +type ChatMessage struct { + Role string `json:"role"` // "system" | "user" | "assistant" | "tool" | "function" + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + + // Legacy function calling + FunctionCall *ChatFunctionCall `json:"function_call,omitempty"` +} + +// ChatContentPart is a typed content part in a multi-modal message. +type ChatContentPart struct { + Type string `json:"type"` // "text" | "image_url" + Text string `json:"text,omitempty"` + ImageURL *ChatImageURL `json:"image_url,omitempty"` +} + +// ChatImageURL contains the URL for an image content part. 
+type ChatImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` // "auto" | "low" | "high" +} + +// ChatTool describes a tool available to the model. +type ChatTool struct { + Type string `json:"type"` // "function" + Function *ChatFunction `json:"function,omitempty"` +} + +// ChatFunction describes a function tool definition. +type ChatFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +// ChatToolCall represents a tool call made by the assistant. +// Index is only populated in streaming chunks (omitted in non-streaming responses). +type ChatToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` // "function" + Function ChatFunctionCall `json:"function"` +} + +// ChatFunctionCall contains the function name and arguments. +type ChatFunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// ChatCompletionsResponse is the non-streaming response from POST /v1/chat/completions. +type ChatCompletionsResponse struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChoice is a single completion choice. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` // "stop" | "length" | "tool_calls" | "content_filter" +} + +// ChatUsage holds token counts in Chat Completions format. 
+type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *ChatTokenDetails `json:"prompt_tokens_details,omitempty"` +} + +// ChatTokenDetails provides a breakdown of token usage. +type ChatTokenDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` +} + +// ChatCompletionsChunk is a single streaming chunk from POST /v1/chat/completions. +type ChatCompletionsChunk struct { + ID string `json:"id"` + Object string `json:"object"` // "chat.completion.chunk" + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChunkChoice `json:"choices"` + Usage *ChatUsage `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` +} + +// ChatChunkChoice is a single choice in a streaming chunk. +type ChatChunkChoice struct { + Index int `json:"index"` + Delta ChatDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` // pointer: null when not final +} + +// ChatDelta carries incremental content in a streaming chunk. 
+type ChatDelta struct { + Role string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` // pointer: omit when not present, null vs "" matters + ToolCalls []ChatToolCall `json:"tool_calls,omitempty"` +} + // --------------------------------------------------------------------------- // Shared constants // --------------------------------------------------------------------------- diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index c300b17d..882d2ebd 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -18,10 +18,12 @@ func DefaultModels() []Model { return []Model{ {Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go new file mode 100644 index 00000000..b80047fb --- /dev/null +++ b/backend/internal/pkg/gemini/models_test.go @@ -0,0 +1,28 @@ +package gemini + +import "testing" + +func TestDefaultModels_ContainsImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byName := make(map[string]Model, len(models)) + for _, model := range models { + byName[model.Name] = model + } + + required := []string{ + "models/gemini-2.5-flash-image", + "models/gemini-3.1-flash-image", + } + + for _, name := range required { + model, ok := byName[name] + if !ok { + t.Fatalf("expected fallback model %q to exist", 
name) + } + if len(model.SupportedGenerationMethods) == 0 { + t.Fatalf("expected fallback model %q to advertise generation methods", name) + } + } +} diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 1fc4d983..195fb06f 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -13,10 +13,12 @@ type Model struct { var DefaultModels = []Model{ {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, + {ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""}, {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, + {ID: "gemini-3.1-flash-image", Type: "model", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. 
diff --git a/backend/internal/pkg/geminicli/models_test.go b/backend/internal/pkg/geminicli/models_test.go new file mode 100644 index 00000000..c1884e2e --- /dev/null +++ b/backend/internal/pkg/geminicli/models_test.go @@ -0,0 +1,23 @@ +package geminicli + +import "testing" + +func TestDefaultModels_ContainsImageModels(t *testing.T) { + t.Parallel() + + byID := make(map[string]Model, len(DefaultModels)) + for _, model := range DefaultModels { + byID[model.ID] = model + } + + required := []string{ + "gemini-2.5-flash-image", + "gemini-3.1-flash-image", + } + + for _, id := range required { + if _, ok := byID[id]; !ok { + t.Fatalf("expected curated Gemini model %q to exist", id) + } + } +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index daa3a878..4e7c418c 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -626,29 +626,6 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } } -func (r *accountRepository) patchSchedulerAccountExtra(ctx context.Context, accountID int64, updates map[string]any) { - if r == nil || r.schedulerCache == nil || accountID <= 0 || len(updates) == 0 { - return - } - account, err := r.schedulerCache.GetAccount(ctx, accountID) - if err != nil { - logger.LegacyPrintf("repository.account", "[Scheduler] patch account extra read failed: id=%d err=%v", accountID, err) - return - } - if account == nil { - return - } - if account.Extra == nil { - account.Extra = make(map[string]any, len(updates)) - } - for key, value := range updates { - account.Extra[key] = value - } - if err := r.schedulerCache.SetAccount(ctx, account); err != nil { - logger.LegacyPrintf("repository.account", "[Scheduler] patch account extra write failed: id=%d err=%v", accountID, err) - } -} - func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) { if r == nil || r.schedulerCache == nil || 
len(accountIDs) == 0 { return @@ -1221,15 +1198,15 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m if affected == 0 { return service.ErrAccountNotFound } - if shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) } } else { - // 观测型 extra 字段不需要触发 bucket 重建,但尽量把单账号缓存补到最新, - // 让 sticky session / GetAccount 命中缓存时也能读到最新快照。 - r.patchSchedulerAccountExtra(ctx, id, updates) + // 观测型 extra 字段不需要触发 bucket 重建,但仍同步单账号快照, + // 让 sticky session / GetAccount 命中缓存时也能读到最新数据, + // 同时避免缓存局部 patch 覆盖掉并发写入的其它账号字段。 + r.syncSchedulerAccountSnapshot(ctx, id) } return nil } @@ -1239,9 +1216,10 @@ func shouldEnqueueSchedulerOutboxForExtraUpdates(updates map[string]any) bool { return false } for key := range updates { - if !isSchedulerNeutralExtraKey(key) { - return true + if isSchedulerNeutralExtraKey(key) { + continue } + return true } return false } @@ -1262,6 +1240,82 @@ func isSchedulerNeutralExtraKey(key string) bool { return false } +func shouldSyncSchedulerSnapshotForExtraUpdates(updates map[string]any) bool { + return codexExtraIndicatesRateLimit(updates, "7d") || codexExtraIndicatesRateLimit(updates, "5h") +} + +func codexExtraIndicatesRateLimit(updates map[string]any, window string) bool { + if len(updates) == 0 { + return false + } + usedValue, ok := updates["codex_"+window+"_used_percent"] + if !ok || !extraValueIndicatesExhausted(usedValue) { + return false + } + return extraValueHasResetMarker(updates["codex_"+window+"_reset_at"]) || + extraValueHasPositiveNumber(updates["codex_"+window+"_reset_after_seconds"]) +} + +func extraValueIndicatesExhausted(value any) bool { + number, ok := extraValueToFloat64(value) + return ok && number >= 100-1e-9 +} + +func extraValueHasPositiveNumber(value any) bool { + 
number, ok := extraValueToFloat64(value) + return ok && number > 0 +} + +func extraValueHasResetMarker(value any) bool { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) != "" + case time.Time: + return !v.IsZero() + case *time.Time: + return v != nil && !v.IsZero() + default: + return false + } +} + +func extraValueToFloat64(value any) (float64, bool) { + switch v := value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + case json.Number: + parsed, err := v.Float64() + return parsed, err == nil + case string: + parsed, err := strconv.ParseFloat(strings.TrimSpace(v), 64) + return parsed, err == nil + default: + return 0, false + } +} + func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 1381fd11..29b699e6 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -633,7 +633,7 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { s.Require().Equal("val", got.Extra["key"]) } -func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndPatchesCache() { +func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() { account := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "acc-extra-neutral", Platform: service.PlatformOpenAI, 
@@ -644,6 +644,7 @@ func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndPatches account.ID: { ID: account.ID, Platform: account.Platform, + Status: service.StatusDisabled, Extra: map[string]any{ "codex_usage_updated_at": "old", }, @@ -670,25 +671,56 @@ func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndPatches s.Require().Zero(outboxCount) s.Require().Len(cacheRecorder.setAccounts, 1) s.Require().NotNil(cacheRecorder.accounts[account.ID]) + s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status) s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"]) } +func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() { + account := mustCreateAccount(s.T(), s.client, &service.Account{ + Name: "acc-extra-codex-exhausted", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Extra: map[string]any{}, + }) + cacheRecorder := &schedulerCacheRecorder{} + s.repo.schedulerCache = cacheRecorder + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) + + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": "2026-03-12T13:00:00Z", + "codex_7d_reset_after_seconds": 86400, + })) + + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(0, count) + s.Require().Len(cacheRecorder.setAccounts, 1) + s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID) + s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status) + s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"]) +} + func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() { account := mustCreateAccount(s.T(), s.client, &service.Account{ Name: "acc-extra-mixed", 
Platform: service.PlatformAntigravity, Extra: map[string]any{}, }) + _, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox") + s.Require().NoError(err) - updates := map[string]any{ + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{ "mixed_scheduling": true, "codex_usage_updated_at": "2026-03-11T10:00:00Z", - } - s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates)) + })) - var outboxCount int - s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount)) - s.Require().Equal(1, outboxCount) + var count int + err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count) + s.Require().NoError(err) + s.Require().Equal(1, count) } // --- GetByCRSAccountID --- diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 95db1819..4c7f38a8 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -452,6 +452,32 @@ func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amo return updated.QuotaUsed, nil } +// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key +// as quota_exhausted, and returns the latest quota state in one round trip. 
+func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) { + query := ` + UPDATE api_keys + SET + quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 AND quota_used + $1 >= quota THEN $2 + ELSE status + END, + updated_at = NOW() + WHERE id = $3 AND deleted_at IS NULL + RETURNING quota_used, quota, key, status + ` + + state := &service.APIKeyQuotaUsageState{} + if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil { + if err == sql.ErrNoRows { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return state, nil +} + func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 80714614..a8989ff2 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -417,6 +417,27 @@ func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") } +func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() { + user := s.mustCreateUser("quota-state@test.com") + key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil) + key.Quota = 3 + key.QuotaUsed = 1 + s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota") + + state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsedAndGetState") + s.Require().NotNil(state) + s.Require().Equal(3.5, state.QuotaUsed) + s.Require().Equal(3.0, state.Quota) + 
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status) + s.Require().Equal(key.Key, state.Key) + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(3.5, got.QuotaUsed) + s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status) +} + // TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 // 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 func TestIncrementQuotaUsed_Concurrent(t *testing.T) { diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index 989573f2..02ca1a3b 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -16,19 +16,7 @@ type opsRepository struct { db *sql.DB } -func NewOpsRepository(db *sql.DB) service.OpsRepository { - return &opsRepository{db: db} -} - -func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) { - if r == nil || r.db == nil { - return 0, fmt.Errorf("nil ops repository") - } - if input == nil { - return 0, fmt.Errorf("nil input") - } - - q := ` +const insertOpsErrorLogSQL = ` INSERT INTO ops_error_logs ( request_id, client_request_id, @@ -70,12 +58,77 @@ INSERT INTO ops_error_logs ( created_at ) VALUES ( $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 -) RETURNING id` +)` + +func NewOpsRepository(db *sql.DB) service.OpsRepository { + return &opsRepository{db: db} +} + +func (r *opsRepository) InsertErrorLog(ctx context.Context, input *service.OpsInsertErrorLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if input == nil { + return 0, fmt.Errorf("nil input") + } var id int64 err := r.db.QueryRowContext( ctx, - q, + insertOpsErrorLogSQL+" RETURNING id", + opsInsertErrorLogArgs(input)..., + ).Scan(&id) + if err != nil { + return 0, err + } + return id, nil +} + +func (r 
*opsRepository) BatchInsertErrorLogs(ctx context.Context, inputs []*service.OpsInsertErrorLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if len(inputs) == 0 { + return 0, nil + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + stmt, err := tx.PrepareContext(ctx, insertOpsErrorLogSQL) + if err != nil { + return 0, err + } + defer func() { + _ = stmt.Close() + }() + + var inserted int64 + for _, input := range inputs { + if input == nil { + continue + } + if _, err = stmt.ExecContext(ctx, opsInsertErrorLogArgs(input)...); err != nil { + return inserted, err + } + inserted++ + } + + if err = tx.Commit(); err != nil { + return inserted, err + } + return inserted, nil +} + +func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any { + return []any{ opsNullString(input.RequestID), opsNullString(input.ClientRequestID), opsNullInt64(input.UserID), @@ -114,11 +167,7 @@ INSERT INTO ops_error_logs ( input.IsRetryable, input.RetryCount, input.CreatedAt, - ).Scan(&id) - if err != nil { - return 0, err } - return id, nil } func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsErrorLogFilter) (*service.OpsErrorLogList, error) { diff --git a/backend/internal/repository/ops_write_pressure_integration_test.go b/backend/internal/repository/ops_write_pressure_integration_test.go new file mode 100644 index 00000000..ebb7a842 --- /dev/null +++ b/backend/internal/repository/ops_write_pressure_integration_test.go @@ -0,0 +1,79 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY") + + repo := 
NewOpsRepository(integrationDB).(*opsRepository) + now := time.Now().UTC() + inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{ + { + RequestID: "batch-ops-1", + ErrorPhase: "upstream", + ErrorType: "upstream_error", + Severity: "error", + StatusCode: 429, + ErrorMessage: "rate limited", + CreatedAt: now, + }, + { + RequestID: "batch-ops-2", + ErrorPhase: "internal", + ErrorType: "api_error", + Severity: "error", + StatusCode: 500, + ErrorMessage: "internal error", + CreatedAt: now.Add(time.Millisecond), + }, + }) + require.NoError(t, err) + require.EqualValues(t, 2, inserted) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count)) + require.Equal(t, 2, count) +} + +func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY") + + accountID := int64(12345) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count)) + require.Equal(t, 1, count) + + time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil)) + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count)) + require.Equal(t, 2, count) +} + +func 
TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) { + ctx := context.Background() + _, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY") + + accountID := int64(67890) + payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}} + payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}} + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1)) + require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2)) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count)) + require.Equal(t, 2, count) +} diff --git a/backend/internal/repository/scheduler_outbox_repo.go b/backend/internal/repository/scheduler_outbox_repo.go index d7bc97da..4b9a9f58 100644 --- a/backend/internal/repository/scheduler_outbox_repo.go +++ b/backend/internal/repository/scheduler_outbox_repo.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "time" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -12,6 +13,8 @@ type schedulerOutboxRepository struct { db *sql.DB } +const schedulerOutboxDedupWindow = time.Second + func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository { return &schedulerOutboxRepository{db: db} } @@ -88,9 +91,37 @@ func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType str } payloadArg = encoded } - _, err := exec.ExecContext(ctx, ` + query := ` INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload) VALUES ($1, $2, $3, $4) - `, eventType, accountID, groupID, payloadArg) + ` + args := []any{eventType, accountID, groupID, payloadArg} + if schedulerOutboxEventSupportsDedup(eventType) { + query = ` + INSERT INTO scheduler_outbox (event_type, 
account_id, group_id, payload) + SELECT $1, $2, $3, $4 + WHERE NOT EXISTS ( + SELECT 1 + FROM scheduler_outbox + WHERE event_type = $1 + AND account_id IS NOT DISTINCT FROM $2 + AND group_id IS NOT DISTINCT FROM $3 + AND created_at >= NOW() - make_interval(secs => $5) + ) + ` + args = append(args, schedulerOutboxDedupWindow.Seconds()) + } + _, err := exec.ExecContext(ctx, query, args...) return err } + +func schedulerOutboxEventSupportsDedup(eventType string) bool { + switch eventType { + case service.SchedulerOutboxEventAccountChanged, + service.SchedulerOutboxEventGroupChanged, + service.SchedulerOutboxEventFullRebuild: + return true + default: + return false + } +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index a69f1595..9fdb233b 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -456,6 +456,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) { subscriptions.POST("/assign", h.Admin.Subscription.Assign) subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) + subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota) subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index ade2ed83..ea40f2f1 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -71,15 +71,8 @@ func RegisterGatewayRoutes( gateway.POST("/responses", h.OpenAIGateway.Responses) gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses) gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) - // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 - gateway.POST("/chat/completions", func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": 
"Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.", - }, - }) - }) + // OpenAI Chat Completions API + gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -100,6 +93,8 @@ func RegisterGatewayRoutes( r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) + // OpenAI Chat Completions API(不带v1前缀的别名) + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index b44f29fd..472551cf 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -45,16 +45,23 @@ const ( // TestEvent represents a SSE event for account testing type TestEvent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Status string `json:"status,omitempty"` - Code string `json:"code,omitempty"` - Data any `json:"data,omitempty"` - Success bool `json:"success,omitempty"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + ImageURL string `json:"image_url,omitempty"` + MimeType 
string `json:"mime_type,omitempty"` + Data any `json:"data,omitempty"` + Success bool `json:"success,omitempty"` + Error string `json:"error,omitempty"` } +const ( + defaultGeminiTextTestPrompt = "hi" + defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." +) + // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository @@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) { // TestAccountConnection tests an account's connection by sending a test request // All account types use full Claude Code client characteristics, only auth header differs // modelID is optional - if empty, defaults to claude.DefaultTestModel -func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { +func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error { ctx := c.Request.Context() // Get account @@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int } if account.IsGemini() { - return s.testGeminiAccountConnection(c, account, modelID) + return s.testGeminiAccountConnection(c, account, modelID, prompt) } if account.Platform == PlatformAntigravity { - return s.routeAntigravityTest(c, account, modelID) + return s.routeAntigravityTest(c, account, modelID, prompt) } if account.Platform == PlatformSora { @@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } // testGeminiAccountConnection tests a Gemini account's connection -func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() // Determine the model to use @@ -462,7 +469,7 @@ func (s 
*AccountTestService) testGeminiAccountConnection(c *gin.Context, account c.Writer.Flush() // Create test payload (Gemini format) - payload := createGeminiTestPayload() + payload := createGeminiTestPayload(testModelID, prompt) // Build request based on account type var req *http.Request @@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string { // routeAntigravityTest 路由 Antigravity 账号的测试请求。 // APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 -func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { if account.Type == AccountTypeAPIKey { if strings.HasPrefix(modelID, "gemini-") { - return s.testGeminiAccountConnection(c, account, modelID) + return s.testGeminiAccountConnection(c, account, modelID, prompt) } return s.testClaudeAccountConnection(c, account, modelID) } @@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT return req, nil } -// createGeminiTestPayload creates a minimal test payload for Gemini API -func createGeminiTestPayload() []byte { +// createGeminiTestPayload creates a minimal test payload for Gemini API. +// Image models use the image-generation path so the frontend can preview the returned image. 
+func createGeminiTestPayload(modelID string, prompt string) []byte { + if isImageGenerationModel(modelID) { + imagePrompt := strings.TrimSpace(prompt) + if imagePrompt == "" { + imagePrompt = defaultGeminiImageTestPrompt + } + + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": imagePrompt}, + }, + }, + }, + "generationConfig": map[string]any{ + "responseModalities": []string{"TEXT", "IMAGE"}, + "imageConfig": map[string]any{ + "aspectRatio": "1:1", + }, + }, + } + bytes, _ := json.Marshal(payload) + return bytes + } + + textPrompt := strings.TrimSpace(prompt) + if textPrompt == "" { + textPrompt = defaultGeminiTextTestPrompt + } + payload := map[string]any{ "contents": []map[string]any{ { "role": "user", "parts": []map[string]any{ - {"text": "hi"}, + {"text": textPrompt}, }, }, }, @@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) if text, ok := partMap["text"].(string); ok && text != "" { s.sendEvent(c, TestEvent{Type: "content", Text: text}) } + if inlineData, ok := partMap["inlineData"].(map[string]any); ok { + mimeType, _ := inlineData["mimeType"].(string) + data, _ := inlineData["data"].(string) + if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" { + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data), + MimeType: mimeType, + }) + } + } } } } @@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in ginCtx, _ := gin.CreateTestContext(w) ginCtx.Request = (&http.Request{}).WithContext(ctx) - testErr := s.TestAccountConnection(ginCtx, accountID, modelID) + testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "") finishedAt := time.Now() body := w.Body.String() diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go new file mode 
100644 index 00000000..5ba04c69 --- /dev/null +++ b/backend/internal/service/account_test_service_gemini_test.go @@ -0,0 +1,59 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestCreateGeminiTestPayload_ImageModel(t *testing.T) { + t.Parallel() + + payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot") + + var parsed struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"contents"` + GenerationConfig struct { + ResponseModalities []string `json:"responseModalities"` + ImageConfig struct { + AspectRatio string `json:"aspectRatio"` + } `json:"imageConfig"` + } `json:"generationConfig"` + } + + require.NoError(t, json.Unmarshal(payload, &parsed)) + require.Len(t, parsed.Contents, 1) + require.Len(t, parsed.Contents[0].Parts, 1) + require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text) + require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities) + require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio) +} + +func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + ctx, recorder := newSoraTestContext() + svc := &AccountTestService{} + + stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") + + err := svc.processGeminiStream(ctx, stream) + require.NoError(t, err) + + body := recorder.Body.String() + require.Contains(t, body, "\"type\":\"content\"") + require.Contains(t, body, "\"text\":\"ok\"") + require.Contains(t, body, "\"type\":\"image\"") + require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"") + require.Contains(t, body, "\"mime_type\":\"image/png\"") +} diff --git 
a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 7c001118..e4245133 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -369,8 +369,11 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou } if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { - if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { + if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) { mergeAccountExtra(account, updates) + if resetAt != nil { + account.RateLimitResetAt = resetAt + } if usage.UpdatedAt == nil { usage.UpdatedAt = &now } @@ -457,26 +460,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no return true } -func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) { +func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) { if account == nil || !account.IsOAuth() { - return nil, nil + return nil, nil, nil } accessToken := account.GetOpenAIAccessToken() if accessToken == "" { - return nil, fmt.Errorf("no access token available") + return nil, nil, fmt.Errorf("no access token available") } modelID := openaipkg.DefaultTestModel payload := createOpenAITestPayload(modelID, true) payloadBytes, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("marshal openai probe payload: %w", err) + return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err) } reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second) defer cancel() req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes)) if err != nil { - return nil, fmt.Errorf("create 
openai probe request: %w", err) + return nil, nil, fmt.Errorf("create openai probe request: %w", err) } req.Host = "chatgpt.com" req.Header.Set("Content-Type", "application/json") @@ -505,43 +508,67 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco ResponseHeaderTimeout: 10 * time.Second, }) if err != nil { - return nil, fmt.Errorf("build openai probe client: %w", err) + return nil, nil, fmt.Errorf("build openai probe client: %w", err) } resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("openai codex probe request failed: %w", err) + return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err) } defer func() { _ = resp.Body.Close() }() - updates, err := extractOpenAICodexProbeUpdates(resp) + updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp) if err != nil { - return nil, err + return nil, nil, err } - if len(updates) > 0 { - go func(accountID int64, updates map[string]any) { - updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer updateCancel() + if len(updates) > 0 || resetAt != nil { + s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt) + return updates, resetAt, nil + } + return nil, nil, nil +} + +func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) { + if s == nil || s.accountRepo == nil || accountID <= 0 { + return + } + if len(updates) == 0 && resetAt == nil { + return + } + + go func() { + updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer updateCancel() + if len(updates) > 0 { _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) - }(account.ID, updates) - return updates, nil + } + if resetAt != nil { + _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) + } + }() +} + +func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) { + if resp == nil { + return nil, nil, 
nil } - return nil, nil + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + baseTime := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, baseTime) + resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime) + if len(updates) > 0 { + return updates, resetAt, nil + } + return nil, resetAt, nil + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + } + return nil, nil, nil } func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { - if resp == nil { - return nil, nil - } - if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) - if len(updates) > 0 { - return updates, nil - } - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) - } - return nil, nil + updates, _, err := extractOpenAICodexProbeSnapshot(resp) + return updates, err } func mergeAccountExtra(account *Account, updates map[string]any) { diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go index 974d9029..a063fe26 100644 --- a/backend/internal/service/account_usage_service_test.go +++ b/backend/internal/service/account_usage_service_test.go @@ -1,11 +1,36 @@ package service import ( + "context" "net/http" "testing" "time" ) +type accountUsageCodexProbeRepo struct { + stubOpenAIAccountRepo + updateExtraCh chan map[string]any + rateLimitCh chan time.Time +} + +func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error { + if r.updateExtraCh != nil { + copied := make(map[string]any, len(updates)) + for k, v := range updates { + copied[k] = v + } + r.updateExtraCh <- copied + } + return nil +} + +func (r *accountUsageCodexProbeRepo) SetRateLimited(_ 
context.Context, _ int64, resetAt time.Time) error { + if r.rateLimitCh != nil { + r.rateLimitCh <- resetAt + } + return nil +} + func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) { t.Parallel() @@ -66,3 +91,60 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) t.Fatalf("codex_7d_used_percent = %v, want 100", got) } } + +func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) { + t.Parallel() + + headers := make(http.Header) + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "604800") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "18000") + headers.Set("x-codex-secondary-window-minutes", "300") + + updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) + if err != nil { + t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err) + } + if len(updates) == 0 { + t.Fatal("expected codex probe updates from 429 headers") + } + if resetAt == nil { + t.Fatal("expected resetAt from exhausted codex headers") + } +} + +func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) { + t.Parallel() + + repo := &accountUsageCodexProbeRepo{ + updateExtraCh: make(chan map[string]any, 1), + rateLimitCh: make(chan time.Time, 1), + } + svc := &AccountUsageService{accountRepo: repo} + resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) + + svc.persistOpenAICodexProbeSnapshot(321, map[string]any{ + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.Format(time.RFC3339), + }, &resetAt) + + select { + case updates := <-repo.updateExtraCh: + if got := updates["codex_7d_used_percent"]; got != 100.0 { + t.Fatalf("codex_7d_used_percent = %v, want 100", got) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex 
probe extra persistence timed out") + } + + select { + case got := <-repo.rateLimitCh: + if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) { + t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt) + } + case <-time.After(2 * time.Second): + t.Fatal("waiting for codex probe rate limit persistence timed out") + } +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 17c5b486..18e9ff7a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "strconv" + "strings" "sync" "time" @@ -110,6 +111,15 @@ func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 { return d.Usage7d } +// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update. +// It is intentionally small so repositories can return it from a single SQL statement. +type APIKeyQuotaUsageState struct { + QuotaUsed float64 + Quota float64 + Key string + Status string +} + // APIKeyCache defines cache operations for API key service type APIKeyCache interface { GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) @@ -817,6 +827,21 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + type quotaStateReader interface { + IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) + } + + if repo, ok := s.apiKeyRepo.(quotaStateReader); ok { + state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" { + s.InvalidateAuthCacheByKey(ctx, state.Key) + } + return nil + } + // Use repository to atomically increment quota_used newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) if err != 
nil { diff --git a/backend/internal/service/api_key_service_quota_test.go b/backend/internal/service/api_key_service_quota_test.go new file mode 100644 index 00000000..2e2f6f78 --- /dev/null +++ b/backend/internal/service/api_key_service_quota_test.go @@ -0,0 +1,170 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type quotaStateRepoStub struct { + quotaBaseAPIKeyRepoStub + stateCalls int + state *APIKeyQuotaUsageState + stateErr error +} + +func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) { + s.stateCalls++ + if s.stateErr != nil { + return nil, s.stateErr + } + if s.state == nil { + return nil, nil + } + out := *s.state + return &out, nil +} + +type quotaStateCacheStub struct { + deleteAuthKeys []string +} + +func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error { + return nil +} + +func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error { + return nil +} + +func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error { + return nil +} + +func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error 
{ + return nil +} + +func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error { + return nil +} + +type quotaBaseAPIKeyRepoStub struct { + getByIDCalls int +} + +func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error { + panic("unexpected Create call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) { + s.getByIDCalls++ + return nil, nil +} +func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} +func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error { + panic("unexpected Update call") +} +func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error { + panic("unexpected Delete call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected CountByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected ExistsByKey call") +} +func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, 
error) { + panic("unexpected SearchAPIKeys call") +} +func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected CountByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} +func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error { + panic("unexpected UpdateLastUsed call") +} +func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + +func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) { + repo := "aStateRepoStub{ + state: &APIKeyQuotaUsageState{ + QuotaUsed: 12, + Quota: 10, + Key: "sk-test-quota", + Status: StatusAPIKeyQuotaExhausted, + }, + } + cache := "aStateCacheStub{} + svc := &APIKeyService{ + apiKeyRepo: repo, + cache: cache, + } + + err := svc.UpdateQuotaUsed(context.Background(), 101, 2) + require.NoError(t, err) + require.Equal(t, 1, repo.stateCalls) + require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id") + require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, 
cache.deleteAuthKeys) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 080de063..8a433a36 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5998,6 +5998,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http intervalCh = intervalTicker.C } + // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) errorEventSent := false sendErrorEvent := func(reason string) { @@ -6187,6 +6203,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http break } flusher.Flush() + lastDataAt = time.Now() } if data != "" { if firstTokenMs == nil && data != "[DONE]" { @@ -6220,6 +6237,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } sendErrorEvent("stream_timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if clientDisconnected { + continue + } + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, + // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 + if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing") + continue + } + flusher.Flush() } } diff 
--git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index d1920140..b0e4d44f 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -1,6 +1,7 @@ package service import ( + "fmt" "strings" ) @@ -226,6 +227,29 @@ func normalizeCodexModel(model string) string { return "gpt-5.1" } +func SupportsVerbosity(model string) bool { + if !strings.HasPrefix(model, "gpt-") { + return true + } + + var major, minor int + n, _ := fmt.Sscanf(model, "gpt-%d.%d", &major, &minor) + + if major > 5 { + return true + } + if major < 5 { + return false + } + + // gpt-5 + if n == 1 { + return true + } + + return minor >= 3 +} + func getNormalizedCodexModel(modelID string) string { if modelID == "" { return "" diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go new file mode 100644 index 00000000..f893eeb9 --- /dev/null +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -0,0 +1,512 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ForwardAsChatCompletions accepts a Chat Completions request body, converts it +// to OpenAI Responses API format, forwards to the OpenAI upstream, and converts +// the response back to Chat Completions format. All account types (OAuth and API +// Key) go through the Responses API conversion path since the upstream only +// exposes the /v1/responses endpoint. 
+func (s *OpenAIGatewayService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + promptCacheKey string, + defaultMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + + // 1. Parse Chat Completions request + var chatReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &chatReq); err != nil { + return nil, fmt.Errorf("parse chat completions request: %w", err) + } + originalModel := chatReq.Model + clientStream := chatReq.Stream + includeUsage := chatReq.StreamOptions != nil && chatReq.StreamOptions.IncludeUsage + + // 2. Convert to Responses and forward + // ChatCompletionsToResponses always sets Stream=true (upstream always streams). + responsesReq, err := apicompat.ChatCompletionsToResponses(&chatReq) + if err != nil { + return nil, fmt.Errorf("convert chat completions to responses: %w", err) + } + + // 3. Model mapping + mappedModel := account.GetMappedModel(originalModel) + if mappedModel == originalModel && defaultMappedModel != "" { + mappedModel = defaultMappedModel + } + responsesReq.Model = mappedModel + + logger.L().Debug("openai chat_completions: model mapping applied", + zap.Int64("account_id", account.ID), + zap.String("original_model", originalModel), + zap.String("mapped_model", mappedModel), + zap.Bool("stream", clientStream), + ) + + // 4. 
Marshal Responses request body, then apply OAuth codex transform + responsesBody, err := json.Marshal(responsesReq) + if err != nil { + return nil, fmt.Errorf("marshal responses request: %w", err) + } + + if account.Type == AccountTypeOAuth { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for codex transform: %w", err) + } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) + } + } + + // 5. Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("get access token: %w", err) + } + + // 6. Build upstream request + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, promptCacheKey, false) + if err != nil { + return nil, fmt.Errorf("build upstream request: %w", err) + } + + if promptCacheKey != "" { + upstreamReq.Header.Set("session_id", generateSessionUUID(promptCacheKey)) + } + + // 7. 
Send request + proxyURL := "" + if account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed") + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + // 8. Handle error response with failover + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && 
(isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } + } + return s.handleChatCompletionsErrorResponse(resp, c, account) + } + + // 9. Handle normal response + var result *OpenAIForwardResult + var handleErr error + if clientStream { + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + } else { + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + } + + // Propagate ServiceTier and ReasoningEffort to result for billing + if handleErr == nil && result != nil { + if responsesReq.ServiceTier != "" { + st := responsesReq.ServiceTier + result.ServiceTier = &st + } + if responsesReq.Reasoning != nil && responsesReq.Reasoning.Effort != "" { + re := responsesReq.Reasoning.Effort + result.ReasoningEffort = &re + } + } + + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if handleErr == nil && account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + } + + return result, handleErr +} + +// handleChatCompletionsErrorResponse reads an upstream error and returns it in +// OpenAI Chat Completions error format. +func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, +) (*OpenAIForwardResult, error) { + return s.handleCompatErrorResponse(resp, c, account, writeChatCompletionsError) +} + +// handleChatBufferedStreamingResponse reads all Responses SSE events from the +// upstream, finds the terminal event, converts to a Chat Completions JSON +// response, and writes it to the client. 
+func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + var finalResponse *apicompat.ResponsesResponse + var usage OpenAIUsage + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + payload := line[6:] + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions buffered: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil { + finalResponse = event.Response + if event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + } + } + + if err := scanner.Err(); err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions buffered: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + if finalResponse == nil { + writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Upstream stream ended without a terminal response event") + return nil, fmt.Errorf("upstream stream ended without terminal event") + } + + chatResp := 
apicompat.ResponsesToChatCompletions(finalResponse, originalModel) + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.JSON(http.StatusOK, chatResp) + + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), + }, nil +} + +// handleChatStreamingResponse reads Responses SSE events from upstream, +// converts each to Chat Completions SSE chunks, and writes them to the client. +func (s *OpenAIGatewayService) handleChatStreamingResponse( + resp *http.Response, + c *gin.Context, + originalModel string, + mappedModel string, + includeUsage bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + requestID := resp.Header.Get("x-request-id") + + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + state := apicompat.NewResponsesEventToChatState() + state.Model = originalModel + state.IncludeUsage = includeUsage + + var usage OpenAIUsage + var firstTokenMs *int + firstChunk := true + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 0, 64*1024), maxLineSize) + + resultWithUsage := func() *OpenAIForwardResult { + return &OpenAIForwardResult{ + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + } + + processDataLine := func(payload string) bool { + if 
firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + var event apicompat.ResponsesStreamEvent + if err := json.Unmarshal([]byte(payload), &event); err != nil { + logger.L().Warn("openai chat_completions stream: failed to parse event", + zap.Error(err), + zap.String("request_id", requestID), + ) + return false + } + + // Extract usage from completion events + if (event.Type == "response.completed" || event.Type == "response.incomplete" || event.Type == "response.failed") && + event.Response != nil && event.Response.Usage != nil { + usage = OpenAIUsage{ + InputTokens: event.Response.Usage.InputTokens, + OutputTokens: event.Response.Usage.OutputTokens, + } + if event.Response.Usage.InputTokensDetails != nil { + usage.CacheReadInputTokens = event.Response.Usage.InputTokensDetails.CachedTokens + } + } + + chunks := apicompat.ResponsesEventToChatChunks(&event, state) + for _, chunk := range chunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + logger.L().Warn("openai chat_completions stream: failed to marshal chunk", + zap.Error(err), + zap.String("request_id", requestID), + ) + continue + } + if _, err := fmt.Fprint(c.Writer, sse); err != nil { + logger.L().Info("openai chat_completions stream: client disconnected", + zap.String("request_id", requestID), + ) + return true + } + } + if len(chunks) > 0 { + c.Writer.Flush() + } + return false + } + + finalizeStream := func() (*OpenAIForwardResult, error) { + if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 { + for _, chunk := range finalChunks { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + continue + } + fmt.Fprint(c.Writer, sse) //nolint:errcheck + } + } + // Send [DONE] sentinel + fmt.Fprint(c.Writer, "data: [DONE]\n\n") //nolint:errcheck + c.Writer.Flush() + return resultWithUsage(), nil + } + + handleScanErr := func(err error) { + if err != nil && !errors.Is(err, context.Canceled) && 
!errors.Is(err, context.DeadlineExceeded) { + logger.L().Warn("openai chat_completions stream: read error", + zap.Error(err), + zap.String("request_id", requestID), + ) + } + } + + // Determine keepalive interval + keepaliveInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { + keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second + } + + // No keepalive: fast synchronous path + if keepaliveInterval <= 0 { + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + } + handleScanErr(scanner.Err()) + return finalizeStream() + } + + // With keepalive: goroutine + channel + select + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + go func() { + defer close(events) + for scanner.Scan() { + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + keepaliveTicker := time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + lastDataAt := time.Now() + + for { + select { + case ev, ok := <-events: + if !ok { + return finalizeStream() + } + if ev.err != nil { + handleScanErr(ev.err) + return finalizeStream() + } + lastDataAt = time.Now() + line := ev.line + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + if processDataLine(line[6:]) { + return resultWithUsage(), nil + } + + case <-keepaliveTicker.C: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + // Send SSE comment as keepalive + if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil { + logger.L().Info("openai chat_completions stream: 
client disconnected during keepalive", + zap.String("request_id", requestID), + ) + return resultWithUsage(), nil + } + c.Writer.Flush() + } + } +} + +// writeChatCompletionsError writes an error response in OpenAI Chat Completions format. +func writeChatCompletionsError(c *gin.Context, statusCode int, errType, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 46fc68a9..e4a3d9c0 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -172,7 +172,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody), + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), } } // Non-failover error: return Anthropic-formatted error to client @@ -219,54 +219,7 @@ func (s *OpenAIGatewayService) handleAnthropicErrorResponse( c *gin.Context, account *Account, ) (*OpenAIForwardResult, error) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) - if upstreamMsg == "" { - upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) - } - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - // Record upstream error details for ops logging - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(body), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, 
upstreamMsg, upstreamDetail) - - // Apply error passthrough rules (matches handleErrorResponse pattern in openai_gateway_service.go) - if status, errType, errMsg, matched := applyErrorPassthroughRule( - c, account.Platform, resp.StatusCode, body, - http.StatusBadGateway, "api_error", "Upstream request failed", - ); matched { - writeAnthropicError(c, status, errType, errMsg) - if upstreamMsg == "" { - upstreamMsg = errMsg - } - if upstreamMsg == "" { - return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) - } - - errType := "api_error" - switch { - case resp.StatusCode == 400: - errType = "invalid_request_error" - case resp.StatusCode == 404: - errType = "not_found_error" - case resp.StatusCode == 429: - errType = "rate_limit_error" - case resp.StatusCode >= 500: - errType = "api_error" - } - - writeAnthropicError(c, resp.StatusCode, errType, upstreamMsg) - return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) + return s.handleCompatErrorResponse(resp, c, account, writeAnthropicError) } // handleAnthropicBufferedStreamingResponse reads all Responses SSE events from diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 44cfc83a..023e4ed4 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -52,6 +52,8 @@ const ( openAIWSRetryJitterRatioDefault = 0.2 openAICompactSessionSeedKey = "openai_compact_session_seed" codexCLIVersion = "0.104.0" + // Codex 限额快照仅用于后台展示/诊断,不需要每个成功请求都立即落库。 + openAICodexSnapshotPersistMinInterval = 30 * time.Second ) // OpenAI allowed headers whitelist (for non-passthrough). 
@@ -255,6 +257,46 @@ type openAIWSRetryMetrics struct { nonRetryableFastFallback atomic.Int64 } +type accountWriteThrottle struct { + minInterval time.Duration + mu sync.Mutex + lastByID map[int64]time.Time +} + +func newAccountWriteThrottle(minInterval time.Duration) *accountWriteThrottle { + return &accountWriteThrottle{ + minInterval: minInterval, + lastByID: make(map[int64]time.Time), + } +} + +func (t *accountWriteThrottle) Allow(id int64, now time.Time) bool { + if t == nil || id <= 0 || t.minInterval <= 0 { + return true + } + + t.mu.Lock() + defer t.mu.Unlock() + + if last, ok := t.lastByID[id]; ok && now.Sub(last) < t.minInterval { + return false + } + t.lastByID[id] = now + + if len(t.lastByID) > 4096 { + cutoff := now.Add(-4 * t.minInterval) + for accountID, writtenAt := range t.lastByID { + if writtenAt.Before(cutoff) { + delete(t.lastByID, accountID) + } + } + } + + return true +} + +var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval) + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -289,6 +331,7 @@ type OpenAIGatewayService struct { openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter + codexSnapshotThrottle *accountWriteThrottle } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -329,17 +372,25 @@ func NewOpenAIGatewayService( nil, "service.openai_gateway", ), - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + 
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), + codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } svc.logOpenAIWSModeBootstrap() return svc } +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ accountRepo: s.accountRepo, @@ -1716,6 +1767,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true markPatchSet("model", normalizedModel) } + + // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 + // 确保高版本模型向低版本模型映射不报错 + if !SupportsVerbosity(normalizedModel) { + if text, ok := reqBody["text"].(map[string]any); ok { + delete(text, "verbosity") + } + } } // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 @@ -2947,6 +3006,120 @@ func (s *OpenAIGatewayService) handleErrorResponse( return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } +// compatErrorWriter is the signature for format-specific error writers used by +// the compat paths (Chat Completions and Anthropic Messages). +type compatErrorWriter func(c *gin.Context, statusCode int, errType, message string) + +// handleCompatErrorResponse is the shared non-failover error handler for the +// Chat Completions and Anthropic Messages compat paths. It mirrors the logic of +// handleErrorResponse (passthrough rules, ShouldHandleErrorCode, rate-limit +// tracking, secondary failover) but delegates the final error write to the +// format-specific writer function. 
+func (s *OpenAIGatewayService) handleCompatErrorResponse( + resp *http.Response, + c *gin.Context, + account *Account, + writeError compatErrorWriter, +) (*OpenAIForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if upstreamMsg == "" { + upstreamMsg = fmt.Sprintf("Upstream error: %d", resp.StatusCode) + } + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // Apply error passthrough rules + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, account.Platform, resp.StatusCode, body, + http.StatusBadGateway, "api_error", "Upstream request failed", + ); matched { + writeError(c, status, errType, errMsg) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + + // Check custom error codes — if the account does not handle this status, + // return a generic error without exposing upstream details. 
+ if !account.ShouldHandleErrorCode(resp.StatusCode) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + writeError(c, http.StatusInternalServerError, "api_error", "Upstream gateway error") + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (not in custom error codes) message=%s", resp.StatusCode, upstreamMsg) + } + + // Track rate limits and decide whether to trigger secondary failover. + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) + } + kind := "http_error" + if shouldDisable { + kind = "failover" + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: kind, + Message: upstreamMsg, + Detail: upstreamDetail, + }) + if shouldDisable { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: body, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + // Map status code to error type and write response + errType := "api_error" + switch { + case resp.StatusCode == 400: + errType = "invalid_request_error" + case resp.StatusCode == 404: + errType = "not_found_error" + case resp.StatusCode == 429: + errType = "rate_limit_error" + case resp.StatusCode >= 500: + errType = "api_error" + } + + writeError(c, resp.StatusCode, errType, upstreamMsg) + return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, 
upstreamMsg) +} + // openaiStreamingResult streaming response result type openaiStreamingResult struct { usage *OpenAIUsage @@ -4050,11 +4223,15 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if len(updates) == 0 && resetAt == nil { return } + shouldPersistUpdates := len(updates) > 0 && s.getCodexSnapshotThrottle().Allow(accountID, now) + if !shouldPersistUpdates && resetAt == nil { + return + } go func() { updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - if len(updates) > 0 { + if shouldPersistUpdates { _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) } if resetAt != nil { diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index 28cb8e00..f5c79923 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -405,6 +405,40 @@ func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_NonExhaustedSnapshotDoesN } } +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_ThrottlesExtraWrites(t *testing.T) { + repo := &openAICodexSnapshotAsyncRepo{ + updateExtraCh: make(chan map[string]any, 2), + rateLimitCh: make(chan time.Time, 2), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexSnapshotThrottle: newAccountWriteThrottle(time.Hour), + } + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: ptrFloat64WS(94), + PrimaryResetAfterSeconds: ptrIntWS(3600), + PrimaryWindowMinutes: ptrIntWS(10080), + SecondaryUsedPercent: ptrFloat64WS(22), + SecondaryResetAfterSeconds: ptrIntWS(1200), + SecondaryWindowMinutes: ptrIntWS(300), + } + + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + svc.updateCodexUsageSnapshot(context.Background(), 777, snapshot) + + select { + case <-repo.updateExtraCh: + case <-time.After(2 * time.Second): + t.Fatal("等待第一次 codex 快照落库超时") + } + + select { + case 
updates := <-repo.updateExtraCh: + t.Fatalf("unexpected second codex snapshot write: %v", updates) + case <-time.After(200 * time.Millisecond): + } +} + func ptrFloat64WS(v float64) *float64 { return &v } func ptrIntWS(v int) *int { return &v } diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 169a5e32..88883180 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -506,6 +506,48 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric( return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { return acc.HasError && acc.TempUnschedulableUntil == nil })), true + case "group_rate_limit_ratio": + if groupID == nil || *groupID <= 0 { + return 0, false + } + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + if availability.Group == nil || availability.Group.TotalAccounts <= 0 { + return 0, true + } + return (float64(availability.Group.RateLimitCount) / float64(availability.Group.TotalAccounts)) * 100, true + case "account_error_ratio": + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID) + if err != nil || availability == nil { + return 0, false + } + total := int64(len(availability.Accounts)) + if total <= 0 { + return 0, true + } + errorCount := countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { + return acc.HasError && acc.TempUnschedulableUntil == nil + }) + return (float64(errorCount) / float64(total)) * 100, true + case "overload_account_count": + if s == nil || s.opsService == nil { + return 0, false + } + availability, err := s.opsService.GetAccountAvailability(ctx, platform, 
groupID) + if err != nil || availability == nil { + return 0, false + } + return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool { + return acc.IsOverloaded + })), true } overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{ diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index f3633eae..0ce9d425 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -7,6 +7,7 @@ import ( type OpsRepository interface { InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) + BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go index e250dea3..c8c66ec6 100644 --- a/backend/internal/service/ops_repo_mock_test.go +++ b/backend/internal/service/ops_repo_mock_test.go @@ -7,6 +7,8 @@ import ( // opsRepoMock is a test-only OpsRepository implementation with optional function hooks. 
type opsRepoMock struct { + InsertErrorLogFn func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) + BatchInsertErrorLogsFn func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) @@ -14,9 +16,19 @@ type opsRepoMock struct { } func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + if m.InsertErrorLogFn != nil { + return m.InsertErrorLogFn(ctx, input) + } return 0, nil } +func (m *opsRepoMock) BatchInsertErrorLogs(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + if m.BatchInsertErrorLogsFn != nil { + return m.BatchInsertErrorLogsFn(ctx, inputs) + } + return int64(len(inputs)), nil +} + func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) { return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil } diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index 767d1704..29f0aa8b 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -121,14 +121,74 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool { } func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error { - if entry == nil { + prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody) + if err != nil { + log.Printf("[Ops] RecordError prepare failed: %v", err) + return err + } + if !ok { return nil } + + if _, err := s.opsRepo.InsertErrorLog(ctx, prepared); err != nil { + // Never bubble up to gateway; best-effort logging. 
+ log.Printf("[Ops] RecordError failed: %v", err) + return err + } + return nil +} + +func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertErrorLogInput) error { + if len(entries) == 0 { + return nil + } + prepared := make([]*OpsInsertErrorLogInput, 0, len(entries)) + for _, entry := range entries { + item, ok, err := s.prepareErrorLogInput(ctx, entry, nil) + if err != nil { + log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err) + continue + } + if ok { + prepared = append(prepared, item) + } + } + if len(prepared) == 0 { + return nil + } + if len(prepared) == 1 { + _, err := s.opsRepo.InsertErrorLog(ctx, prepared[0]) + if err != nil { + log.Printf("[Ops] RecordErrorBatch single insert failed: %v", err) + } + return err + } + + if _, err := s.opsRepo.BatchInsertErrorLogs(ctx, prepared); err != nil { + log.Printf("[Ops] RecordErrorBatch failed, fallback to single inserts: %v", err) + var firstErr error + for _, entry := range prepared { + if _, insertErr := s.opsRepo.InsertErrorLog(ctx, entry); insertErr != nil { + log.Printf("[Ops] RecordErrorBatch fallback insert failed: %v", insertErr) + if firstErr == nil { + firstErr = insertErr + } + } + } + return firstErr + } + return nil +} + +func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) { + if entry == nil { + return nil, false, nil + } if !s.IsMonitoringEnabled(ctx) { - return nil + return nil, false, nil } if s.opsRepo == nil { - return nil + return nil, false, nil } // Ensure timestamps are always populated. @@ -185,85 +245,88 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn } } - // Sanitize + serialize upstream error events list. 
- if len(entry.UpstreamErrors) > 0 { - const maxEvents = 32 - events := entry.UpstreamErrors - if len(events) > maxEvents { - events = events[len(events)-maxEvents:] + if err := sanitizeOpsUpstreamErrors(entry); err != nil { + return nil, false, err + } + + return entry, true, nil +} + +func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error { + if entry == nil || len(entry.UpstreamErrors) == 0 { + return nil + } + + const maxEvents = 32 + events := entry.UpstreamErrors + if len(events) > maxEvents { + events = events[len(events)-maxEvents:] + } + + sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events)) + for _, ev := range events { + if ev == nil { + continue + } + out := *ev + + out.Platform = strings.TrimSpace(out.Platform) + out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) + out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) + + if out.AccountID < 0 { + out.AccountID = 0 + } + if out.UpstreamStatusCode < 0 { + out.UpstreamStatusCode = 0 + } + if out.AtUnixMs < 0 { + out.AtUnixMs = 0 } - sanitized := make([]*OpsUpstreamErrorEvent, 0, len(events)) - for _, ev := range events { - if ev == nil { - continue - } - out := *ev + msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message)) + msg = truncateString(msg, 2048) + out.Message = msg - out.Platform = strings.TrimSpace(out.Platform) - out.UpstreamRequestID = truncateString(strings.TrimSpace(out.UpstreamRequestID), 128) - out.Kind = truncateString(strings.TrimSpace(out.Kind), 64) + detail := strings.TrimSpace(out.Detail) + if detail != "" { + // Keep upstream detail small; request bodies are not stored here, only upstream error payloads. 
+ sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes) + out.Detail = sanitizedDetail + } else { + out.Detail = "" + } - if out.AccountID < 0 { - out.AccountID = 0 - } - if out.UpstreamStatusCode < 0 { - out.UpstreamStatusCode = 0 - } - if out.AtUnixMs < 0 { - out.AtUnixMs = 0 - } - - msg := sanitizeUpstreamErrorMessage(strings.TrimSpace(out.Message)) - msg = truncateString(msg, 2048) - out.Message = msg - - detail := strings.TrimSpace(out.Detail) - if detail != "" { - // Keep upstream detail small; request bodies are not stored here, only upstream error payloads. - sanitizedDetail, _ := sanitizeErrorBodyForStorage(detail, opsMaxStoredErrorBodyBytes) - out.Detail = sanitizedDetail - } else { - out.Detail = "" - } - - out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody) - if out.UpstreamRequestBody != "" { - // Reuse the same sanitization/trimming strategy as request body storage. - // Keep it small so it is safe to persist in ops_error_logs JSON. - sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) - if sanitized != "" { - out.UpstreamRequestBody = sanitized - if truncated { - out.Kind = strings.TrimSpace(out.Kind) - if out.Kind == "" { - out.Kind = "upstream" - } - out.Kind = out.Kind + ":request_body_truncated" + out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody) + if out.UpstreamRequestBody != "" { + // Reuse the same sanitization/trimming strategy as request body storage. + // Keep it small so it is safe to persist in ops_error_logs JSON. 
+ sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024) + if sanitizedBody != "" { + out.UpstreamRequestBody = sanitizedBody + if truncated { + out.Kind = strings.TrimSpace(out.Kind) + if out.Kind == "" { + out.Kind = "upstream" } - } else { - out.UpstreamRequestBody = "" + out.Kind = out.Kind + ":request_body_truncated" } + } else { + out.UpstreamRequestBody = "" } - - // Drop fully-empty events (can happen if only status code was known). - if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { - continue - } - - evCopy := out - sanitized = append(sanitized, &evCopy) } - entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized) - entry.UpstreamErrors = nil + // Drop fully-empty events (can happen if only status code was known). + if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { + continue + } + + evCopy := out + sanitized = append(sanitized, &evCopy) } - if _, err := s.opsRepo.InsertErrorLog(ctx, entry); err != nil { - // Never bubble up to gateway; best-effort logging. - log.Printf("[Ops] RecordError failed: %v", err) - return err - } + entry.UpstreamErrorsJSON = marshalOpsUpstreamErrors(sanitized) + entry.UpstreamErrors = nil return nil } diff --git a/backend/internal/service/ops_service_batch_test.go b/backend/internal/service/ops_service_batch_test.go new file mode 100644 index 00000000..f3a14d7f --- /dev/null +++ b/backend/internal/service/ops_service_batch_test.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) { + t.Parallel() + + var captured []*OpsInsertErrorLogInput + repo := &opsRepoMock{ + BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + captured = append(captured, inputs...) 
+ return int64(len(inputs)), nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + msg := " upstream failed: https://example.com?access_token=secret-value " + detail := `{"authorization":"Bearer secret-token"}` + entries := []*OpsInsertErrorLogInput{ + { + ErrorBody: `{"error":"bad","access_token":"secret"}`, + UpstreamStatusCode: intPtr(-10), + UpstreamErrorMessage: strPtr(msg), + UpstreamErrorDetail: strPtr(detail), + UpstreamErrors: []*OpsUpstreamErrorEvent{ + { + AccountID: -2, + UpstreamStatusCode: 429, + Message: " token leaked ", + Detail: `{"refresh_token":"secret"}`, + UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`, + }, + }, + }, + { + ErrorPhase: "upstream", + ErrorType: "upstream_error", + CreatedAt: time.Now().UTC(), + }, + } + + require.NoError(t, svc.RecordErrorBatch(context.Background(), entries)) + require.Len(t, captured, 2) + + first := captured[0] + require.Equal(t, "internal", first.ErrorPhase) + require.Equal(t, "api_error", first.ErrorType) + require.Nil(t, first.UpstreamStatusCode) + require.NotNil(t, first.UpstreamErrorMessage) + require.NotContains(t, *first.UpstreamErrorMessage, "secret-value") + require.Contains(t, *first.UpstreamErrorMessage, "access_token=***") + require.NotNil(t, first.UpstreamErrorDetail) + require.NotContains(t, *first.UpstreamErrorDetail, "secret-token") + require.NotContains(t, first.ErrorBody, "secret") + require.Nil(t, first.UpstreamErrors) + require.NotNil(t, first.UpstreamErrorsJSON) + require.NotContains(t, *first.UpstreamErrorsJSON, "secret") + require.Contains(t, *first.UpstreamErrorsJSON, "[REDACTED]") + + second := captured[1] + require.Equal(t, "upstream", second.ErrorPhase) + require.Equal(t, "upstream_error", second.ErrorType) + require.False(t, second.CreatedAt.IsZero()) +} + +func TestOpsServiceRecordErrorBatch_FallsBackToSingleInsert(t *testing.T) { + t.Parallel() + + var ( + batchCalls int + singleCalls int + ) + 
repo := &opsRepoMock{ + BatchInsertErrorLogsFn: func(ctx context.Context, inputs []*OpsInsertErrorLogInput) (int64, error) { + batchCalls++ + return 0, errors.New("batch failed") + }, + InsertErrorLogFn: func(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + singleCalls++ + return int64(singleCalls), nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + err := svc.RecordErrorBatch(context.Background(), []*OpsInsertErrorLogInput{ + {ErrorMessage: "first"}, + {ErrorMessage: "second"}, + }) + require.NoError(t, err) + require.Equal(t, 1, batchCalls) + require.Equal(t, 2, singleCalls) +} + +func strPtr(v string) *string { + return &v +} diff --git a/backend/internal/service/subscription_reset_quota_test.go b/backend/internal/service/subscription_reset_quota_test.go new file mode 100644 index 00000000..36aa177f --- /dev/null +++ b/backend/internal/service/subscription_reset_quota_test.go @@ -0,0 +1,166 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage, +// 其余方法继承 userSubRepoNoop(panic)。 +type resetQuotaUserSubRepoStub struct { + userSubRepoNoop + + sub *UserSubscription + + resetDailyCalled bool + resetWeeklyCalled bool + resetDailyErr error + resetWeeklyErr error +} + +func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) { + if r.sub == nil || r.sub.ID != id { + return nil, ErrSubscriptionNotFound + } + cp := *r.sub + return &cp, nil +} + +func (r *resetQuotaUserSubRepoStub) ResetDailyUsage(_ context.Context, _ int64, windowStart time.Time) error { + r.resetDailyCalled = true + if r.resetDailyErr == nil && r.sub != nil { + r.sub.DailyUsageUSD = 0 + r.sub.DailyWindowStart = &windowStart + } + return r.resetDailyErr +} + +func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ 
int64, _ time.Time) error { + r.resetWeeklyCalled = true + return r.resetWeeklyErr +} + +func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService { + return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil) +} + +func TestAdminResetQuota_ResetBoth(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 1, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 1, true, true) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") + require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") +} + +func TestAdminResetQuota_ResetDailyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 2, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 2, true, false) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage") + require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage") +} + +func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 3, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + result, err := svc.AdminResetQuota(context.Background(), 3, false, true) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage") + require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage") +} + +func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 7, UserID: 10, GroupID: 20}, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 7, false, false) + + require.ErrorIs(t, err, ErrInvalidInput) + require.False(t, stub.resetDailyCalled) + require.False(t, 
stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{sub: nil} + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 999, true, true) + + require.ErrorIs(t, err, ErrSubscriptionNotFound) + require.False(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 4, UserID: 10, GroupID: 20}, + resetDailyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 4, true, true) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetDailyCalled) + require.False(t, stub.resetWeeklyCalled, "daily 失败后不应继续调用 weekly") +} + +func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) { + dbErr := errors.New("db error") + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ID: 5, UserID: 10, GroupID: 20}, + resetWeeklyErr: dbErr, + } + svc := newResetQuotaSvc(stub) + + _, err := svc.AdminResetQuota(context.Background(), 5, false, true) + + require.ErrorIs(t, err, dbErr) + require.True(t, stub.resetWeeklyCalled) +} + +func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) { + stub := &resetQuotaUserSubRepoStub{ + sub: &UserSubscription{ + ID: 6, + UserID: 10, + GroupID: 20, + DailyUsageUSD: 99.9, + }, + } + + svc := newResetQuotaSvc(stub) + result, err := svc.AdminResetQuota(context.Background(), 6, true, false) + + require.NoError(t, err) + // ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零, + // 服务应返回第二次 GetByID 的刷新值而非初始的 99.9 + require.Equal(t, float64(0), result.DailyUsageUSD, "返回的订阅应反映已归零的用量") + require.True(t, stub.resetDailyCalled) +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 57e04266..55f029fa 100644 --- 
a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -31,6 +31,7 @@ var ( ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") + ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") @@ -695,6 +696,36 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart) } +// AdminResetQuota manually resets the daily and/or weekly usage windows. +// Uses startOfDay(now) as the new window start, matching automatic resets. 
+func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) { + if !resetDaily && !resetWeekly { + return nil, ErrInvalidInput + } + sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) + if err != nil { + return nil, err + } + windowStart := startOfDay(time.Now()) + if resetDaily { + if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + if resetWeekly { + if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, windowStart); err != nil { + return nil, err + } + } + // Invalidate caches, same as CheckAndResetWindows + s.InvalidateSubCache(sub.UserID, sub.GroupID) + if s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID) + } + // Return the refreshed subscription from DB + return s.userSubRepo.GetByID(ctx, subscriptionID) +} + // CheckAndResetWindows 检查并重置过期的窗口 func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error { // 使用当天零点作为新窗口起始时间 diff --git a/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql b/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql new file mode 100644 index 00000000..f3cb3d37 --- /dev/null +++ b/backend/migrations/071_add_gemini25_flash_image_to_model_mapping.sql @@ -0,0 +1,51 @@ +-- Add gemini-2.5-flash-image aliases to Antigravity model_mapping +-- +-- Background: +-- Gemini native image generation now relies on gemini-2.5-flash-image, and +-- existing Antigravity accounts with persisted model_mapping need this alias in +-- order to participate in mixed scheduling from gemini groups. +-- +-- Strategy: +-- Overwrite the stored model_mapping so it matches DefaultAntigravityModelMapping +-- in constants.go, including legacy gemini-3-pro-image aliases. 
+ +UPDATE accounts +SET credentials = jsonb_set( + credentials, + '{model_mapping}', + '{ + "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-6": "claude-opus-4-6-thinking", + "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", + "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-5", + "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "gemini-2.5-flash": "gemini-2.5-flash", + "gemini-2.5-flash-image": "gemini-2.5-flash-image", + "gemini-2.5-flash-image-preview": "gemini-2.5-flash-image", + "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", + "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", + "gemini-2.5-pro": "gemini-2.5-pro", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + "gpt-oss-120b-medium": "gpt-oss-120b-medium", + "tab_flash_lite_preview": "tab_flash_lite_preview" + }'::jsonb +) +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping' IS NOT NULL; diff --git a/frontend/src/api/admin/subscriptions.ts b/frontend/src/api/admin/subscriptions.ts index 9f21056f..d06e0774 100644 --- a/frontend/src/api/admin/subscriptions.ts +++ b/frontend/src/api/admin/subscriptions.ts @@ -120,6 +120,23 @@ export async function revoke(id: 
number): Promise<{ message: string }> { return data } +/** + * Reset daily and/or weekly usage quota for a subscription + * @param id - Subscription ID + * @param options - Which windows to reset + * @returns Updated subscription + */ +export async function resetQuota( + id: number, + options: { daily: boolean; weekly: boolean } +): Promise<Subscription> { + const { data } = await apiClient.post<Subscription>( + `/admin/subscriptions/${id}/reset-quota`, + options + ) + return data +} + /** * List subscriptions by group * @param groupId - Group ID @@ -170,6 +187,7 @@ export const subscriptionsAPI = { bulkAssign, extend, revoke, + resetQuota, listByGroup, listByUser } diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 1dc4f287..220b5c8b 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -176,6 +176,7 @@ const formatScopeName = (scope: string): string => { 'gemini-2.5-flash-lite': 'G25FL', 'gemini-2.5-flash-thinking': 'G25FT', 'gemini-2.5-pro': 'G25P', + 'gemini-2.5-flash-image': 'G25I', // Gemini 3 系列 'gemini-3-flash': 'G3F', 'gemini-3.1-pro-high': 'G3PH', diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 792a8f45..e731a7b1 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -15,7 +15,7 @@
- +
{{ account.name }}
@@ -61,6 +61,17 @@ {{ t('admin.accounts.soraTestHint') }}
+
+