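diff --git a/PR_REPORT_20260311_db_write_hotspots.md b/PR_REPORT_20260311_db_write_hotspots.md
new file mode 100644
index 00000000..54db3c92
--- /dev/null
+++ b/PR_REPORT_20260311_db_write_hotspots.md
@@ -0,0 +1,307 @@
+# PR Report: Investigating DB Write Hotspots and Admin Query Congestion
+
+## Background
+
+Production showed several distinct symptoms during peak hours:
+
+- The admin dashboard endpoint timed out frequently; `/api/v1/admin/dashboard/snapshot-v2` at one point took more than 50s
+- The admin balance top-up endpoint `/api/v1/admin/users/:id/balance` timed out after more than 15s
+- Session refresh, billing deduction, and error logging produced large numbers of `context deadline exceeded` errors at peak
+- PostgreSQL connections were at one point exhausted; after the connection pool was rolled back, the main problem shifted to WAL/flush congestion
+
+This report is based on the current source in `/home/ius/sub2api`. Its goal is a remediation plan that can be split directly into PRs.
+
+## Conclusion
+
+The main cause of this incident is not a single "slow SQL" statement. The request success path performs too many synchronous database writes, and some admin queries still scan `usage_logs` directly; together these amplify PostgreSQL WAL flushing, hot-row updates, and the outbox rebuild chain.
+
+At the code level there are 6 core problems.
+
+### 1. Too many synchronous writes on the request success path
+
+`postUsageBilling` at `backend/internal/service/gateway_service.go:6594` may synchronously trigger the following writes after a single successful request:
+
+- `userRepo.DeductBalance`
+- `APIKeyService.UpdateQuotaUsed`
+- `APIKeyService.UpdateRateLimitUsage`
+- `accountRepo.IncrementQuotaUsed`
+- `deferredService.ScheduleLastUsedUpdate` (already deferred and batched, which is the right direction)
+
+In other words, one successful request results in 3 to 5 writes rather than 1.
+
+This matches what was observed in production:
+
+- `UPDATE accounts SET extra = ...`
+- `INSERT INTO usage_logs ...`
+- `INSERT INTO ops_error_logs ...`
+- `scheduler_outbox` backlog
+
+### 2. API key quota updates add extra read/write amplification
+
+`UpdateQuotaUsed` at `backend/internal/service/api_key_service.go:815` currently works as follows:
+
+1. `IncrementQuotaUsed`
+2. `GetByID`
+3. `Update` if the quota is exceeded
+
+The corresponding repository implementation:
+
+- `backend/internal/repository/api_key_repo.go:441` only performs the increment
+- The service then reads the full API key back from the table
+- After that it may update the status by rewriting the whole row
+
+This turns "update the API key quota after each charge" from 1 SQL statement into up to 3 database round trips.
+
+### 3. `accounts.extra` is used as a high-frequency hot-write field
+
+The two heaviest hotspots both land on `accounts.extra`:
+
+- `backend/internal/repository/account_repo.go:1159` `UpdateExtra`
+- `backend/internal/repository/account_repo.go:1683` `IncrementQuotaUsed`
+
+There are two problems:
+
+1. Both rewrite the entire JSONB value and update `updated_at`
+2. `UpdateExtra` additionally inserts a `scheduler_outbox` row after every write
+
+`UpdateExtra` in particular is now called at high frequency from several places:
+
+- `backend/internal/service/openai_gateway_service.go:4039` persists the Codex rate-limit snapshot
+- `backend/internal/service/ratelimit_service.go:903` persists the OpenAI Codex snapshot
+- `backend/internal/service/ratelimit_service.go:1013` / `1025` update session window utilization
+
+These "monitoring/quota snapshot" writes do not change whether an account is schedulable, yet they still go through:
+
+- a JSONB update
+- `updated_at`
+- `scheduler_outbox`
+
+This is clear write amplification.
+
+### 4. `scheduler_outbox` is designed around "one row per state change" and back-pressures the scheduler at peak
+
+`enqueueSchedulerOutbox` at `backend/internal/repository/scheduler_outbox_repo.go:79` is very lightweight, but it is called a great deal.
+
+For example:
+
+- `UpdateExtra` enqueues `AccountChanged` on every call
+- `BatchUpdateLastUsed` also enqueues an `AccountLastUsed` event
+- Account rate-limit, overload, and error state transitions all enqueue as well
+
+The corresponding outbox worker is at:
+
+- `backend/internal/service/scheduler_snapshot_service.go:199`
+- `backend/internal/service/scheduler_snapshot_service.go:219`
+
+It continuously pulls from the outbox and then triggers `GetByID`, `rebuildBucket`, and `loadAccountsFromDB`.
+
+So when high-frequency writes make the outbox grow, the system not only writes more, it also drives additional reads and cache rebuilds in return.
+
+### 5. Only part of the dashboard uses pre-aggregation; `models/groups/users-trend` still scan `usage_logs` directly
+
+The good news is that `dashboard stats` itself already uses the pre-aggregation tables:
+
+- `backend/internal/repository/usage_log_repo.go:306`
+- `backend/internal/repository/usage_log_repo.go:420`
+- The pre-aggregation tables are defined in `backend/migrations/034_usage_dashboard_aggregation_tables.sql:1`
+
+But stats is not the only slow part of the admin backend.
+
+By default `snapshot-v2` fetches all of the following:
+
+- stats
+- trend
+- model stats
+
+See:
+
+- `backend/internal/handler/admin/dashboard_snapshot_v2_handler.go:68`
+
+Of these:
+
+- `GetUsageTrendWithFilters` uses pre-aggregation only for "no filters, day/hour" queries, see `usage_log_repo.go:1657`
+- `GetModelStatsWithFilters` scans `usage_logs` directly, see `usage_log_repo.go:1805`
+- `GetGroupStatsWithFilters` scans `usage_logs` directly, see `usage_log_repo.go:1872`
+- `GetUserUsageTrend` scans `usage_logs` directly, see `usage_log_repo.go:1101`
+- `GetAPIKeyUsageTrend` scans `usage_logs` directly, see `usage_log_repo.go:1046`
+
+So in production:
+
+- stats is fast
+- but `snapshot-v2` is still slow
+- `/admin/dashboard/users-trend` is also slow on its own
+
+This is fully consistent with the logs you saw in production.
+
+### 6. Admin balance top-up is "read user -> update whole user -> insert audit record"
+
+`UpdateUserBalance` at `backend/internal/service/admin_service.go:694` currently works as follows:
+
+1. `GetByID`
+2. Change the balance in memory
+3. `userRepo.Update`
+4. `redeemCodeRepo.Create` records the admin balance-adjustment history
+
+`userRepo.Update` updates the whole user object and also synchronizes allowed groups in a transaction:
+
+- `backend/internal/repository/user_repo.go:118`
+
+This endpoint is not necessarily heavy under normal conditions, but when the database is already unstable it is more fragile than an atomic `UPDATE users SET balance = balance + $1`.
+
+## Additional Observations
+
+### `ops_error_logs` is already asynchronous, but each insert is still heavy
+
+The error logging middleware already uses a queue to smooth peaks:
+
+- `backend/internal/handler/ops_error_logger.go:69`
+- `backend/internal/handler/ops_error_logger.go:106`
+
+This is the right direction.
+
+But the target table itself is heavy:
+
+- `backend/internal/repository/ops_repo.go:23`
+- `backend/migrations/033_ops_monitoring_vnext.sql:69`
+- `backend/migrations/033_ops_monitoring_vnext.sql:470`
+
+`ops_error_logs` not only has many columns, it also carries multiple B-Tree and trigram indexes. At high error rates, even asynchronous writes still put pressure on WAL and I/O.
+
+## Suggested PR Split
+
+Split the work into 4 PRs; do not change the database model, admin queries, and admin endpoints in a single PR.
+
+### PR 1: Reduce synchronous writes on the request success path
+
+Goal: bring the number of synchronous writes per successful request down from 3 to 5 to, as far as possible, 1 to 2.
+
+Suggested changes:
+
+1. Change `APIKeyService.UpdateQuotaUsed` to a single SQL statement
+   - Add a repo method, for example `IncrementQuotaUsedAndMaybeExhaust`
+   - Perform both `quota_used += ?` and `status = quota_exhausted` in SQL
+   - Return the minimal fields `key/status/quota/quota_used` and invalidate the cache directly
+   - Remove the current `Increment -> GetByID -> Update`
+
+2. Move account quota counters out of `accounts.extra`
+   - Ideal: add structured columns or a separate `account_quota_counters` table
+   - Second best: at least take `quota_used/quota_daily_used/quota_weekly_used` out of the JSONB
+
+3. Do not enqueue outbox events for "monitoring-only extra fields"
+   - For example the codex snapshot and session_window_utilization
+   - These fields do not affect scheduling and should not trigger `SchedulerOutboxEventAccountChanged`
+
+4. Reuse the existing `DeferredService` approach
+   - `last_used` is already flushed in batches, see `deferred_service.go:41`
+   - This can be extended with a `deferred quota snapshot flush`
+
+Expected benefits:
+
+- Directly reduces WAL write volume
+- Reduces hot-row lock contention on `accounts`
+- Slows outbox growth
+
+### PR 2: Complete dashboard pre-aggregation/caching to stop scanning `usage_logs`
+
+Goal: admin dashboard endpoints no longer scan the large table directly.
+
+Suggested changes:
+
+1. Add hourly/daily pre-aggregation tables for `users-trend` / `api-keys-trend`
+2. Add daily aggregation tables for `model stats` / `group stats`
+3. Add per-section caching to `snapshot-v2`
+   - `stats`
+   - `trend`
+   - `models`
+   - `groups`
+   - `users_trend`
+   so that a miss on one section does not force the whole snapshot to rescan the database
+4. Optional: change the default of `include_model_stats` from `true` to `false`
+   - At minimum this makes the default dashboard usable again first, with heavy modules loaded on demand
+
+Expected benefits:
+
+- `snapshot-v2`
+- `/admin/dashboard/users-trend`
+- `/admin/dashboard/api-keys-trend`
+
+These endpoints go from "degrading linearly with data volume" to "roughly fixed cost".
+
+### PR 3: Simplify the admin balance top-up path
+
+Goal: admin top-up/deduction no longer depends on updating the whole user object.
+
+Suggested changes:
+
+1. Add atomic repo methods
+   - `SetBalance(userID, amount)`
+   - `AddBalance(userID, delta)`
+   - `SubtractBalance(userID, delta)`
+
+2. Change `UpdateUserBalance` to:
+   - Execute the atomic SQL first
+   - Then read back only the minimal required fields to return
+   - Write the audit record asynchronously or as a degradable write
+
+3. The audit record should be renamed or moved to its own table
+   - Admin balance adjustments are currently stored in `redeem_codes`, which is semantically unclean
+
+Expected benefits:
+
+- `/api/v1/admin/users/:id/balance` is more stable when the database is unstable
+- Smaller failure surface; no longer dragged down by the allowed groups sync transaction
+
+### PR 4: Add "drop policies" and "circuit-breaker metrics" for heavy write paths
+
+Goal: protect the main path first at peak, so non-critical writes cannot take down the database.
+
+Suggested changes:
+
+1. `ops_error_logs`
+   - Add sampling or a severity-level switch
+   - Aggregate repeated 429/5xx errors into counts instead of inserting row by row
+   - Add stricter switches for storing request body / headers
+
+2. `scheduler_outbox`
+   - Add a coalesce/merge mechanism
+   - Merge multiple `AccountChanged` events for the same account within a short window into one
+
+3. Fill in metrics
+   - outbox backlog
+   - ops error queue dropped
+   - deferred flush lag
+   - account extra write QPS
+
+## Recommended Implementation Order
+
+1. Do PR 1 first
+   - This is the main path of this production incident
+2. Then PR 2
+   - Fixes the slow admin dashboard
+3. Then PR 3
+   - Fixes the fragile admin top-up endpoint
+4. Finally PR 4
+   - Provides long-term protection
+
+## Verification Plan
+
+Run the same set of checks before merging each PR:
+
+1. Load-test the request success path and record the number of SQL statements per request
+2. Observe PostgreSQL:
+   - `pg_stat_activity`
+   - `pg_stat_statements`
+   - `WALWrite` / `WalSync`
+   - WAL growth per minute
+3. Observe endpoints:
+   - `/api/v1/auth/refresh`
+   - `/api/v1/admin/dashboard/snapshot-v2`
+   - `/api/v1/admin/dashboard/users-trend`
+   - `/api/v1/admin/users/:id/balance`
+4. Observe queues:
+   - `ops_error_logs` queue length / dropped
+   - `scheduler_outbox` backlog
+
+## Summary Usable Directly as the PR Description
+
+This PR reduces database write amplification on the request success path and removes several hot-path writes from `accounts.extra` + `scheduler_outbox`. It also prepares dashboard endpoints to rely on pre-aggregated data instead of scanning `usage_logs` under load. The goal is to keep admin dashboard, balance update, auth refresh, and billing-related paths stable under sustained 500+ RPS traffic.
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 034c70ec..76b2d0db 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -81,6 +81,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemHandler := handler.NewRedeemHandler(redeemService) @@ -162,9 +163,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, 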
billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index de876098..e90e56af 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -934,9 +934,10 @@ type DashboardAggregationConfig struct { // DashboardAggregationRetentionConfig 预聚合保留窗口 type DashboardAggregationRetentionConfig struct { - UsageLogsDays int `mapstructure:"usage_logs_days"` - HourlyDays int `mapstructure:"hourly_days"` - 
DailyDays int `mapstructure:"daily_days"` + UsageLogsDays int `mapstructure:"usage_logs_days"` + UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` } // UsageCleanupConfig 使用记录清理任务配置 @@ -1301,6 +1302,7 @@ func setDefaults() { viper.SetDefault("dashboard_aggregation.backfill_enabled", false) viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365) viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) viper.SetDefault("dashboard_aggregation.recompute_days", 2) @@ -1758,6 +1760,12 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") } + if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") } @@ -1780,6 +1788,14 @@ func (c *Config) Validate() error { if c.DashboardAgg.Retention.UsageLogsDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") } + if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative") + } + if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 && + c.DashboardAgg.Retention.UsageLogsDays > 0 && 
+ c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { + return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") + } if c.DashboardAgg.Retention.HourlyDays < 0 { return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 79fcc6d0..abb76549 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) } + if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 { + t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays) + } if cfg.DashboardAgg.Retention.HourlyDays != 180 { t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays) } @@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 }, wantErr: "dashboard_aggregation.retention.usage_logs_days", }, + { + name: "dashboard aggregation dedup retention", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageBillingDedupDays = 0 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, + { + name: "dashboard aggregation dedup retention smaller than usage logs", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.Retention.UsageLogsDays = 30 + c.DashboardAgg.Retention.UsageBillingDedupDays = 29 + }, + wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days", + }, { name: "dashboard aggregation disabled 
interval", mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 }, diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 4441cf07..676ba0e1 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: currentAPIKey, - User: currentAPIKey.User, - Account: account, - Subscription: currentSubscription, - UserAgent: userAgent, - IPAddress: clientIP, - ForceCacheBilling: fs.ForceCacheBilling, - 
APIKeyService: h.apiKeyService, + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: account, + Subscription: currentSubscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + ForceCacheBilling: fs.ForceCacheBilling, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 0c94d50b..6bcc0003 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // accountRepo (not used: scheduler snapshot hit) &fakeGroupRepo{group: group}, nil, // usageLogRepo + nil, // usageBillingRepo nil, // userRepo nil, // userSubRepo nil, // userGroupRateRepo diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 50af9c8f..9a16ff3a 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + requestPayloadHash := service.HashUsageRequestPayload(body) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, @@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { Subscription: subscription, UserAgent: userAgent, IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, diff 
--git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 8567b52b..d23c7efe 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( 
zap.String("component", "handler.openai_gateway.messages"), @@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) h.submitUsageRecordTask(func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), + APIKeyService: h.apiKeyService, }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 30a761bd..dab17673 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac // newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( - accountRepo, nil, nil, nil, nil, nil, nil, nil, + accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 48c1e451..06abdf60 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -399,17 +399,19 @@ func (h 
*SoraGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestPayloadHash := service.HashUsageRequestPayload(body) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + RequestPayloadHash: requestPayloadHash, }); err != nil { logger.L().With( zap.String("component", "handler.sora_gateway.chat_completions"), diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 688c5d12..088946e7 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -431,6 +431,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { nil, nil, nil, + nil, testutil.StubGatewayCache{}, cfg, nil, diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index 59bbd6a3..e82a73a3 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -17,6 +17,9 @@ type dashboardAggregationRepository struct { sql sqlExecutor } +const usageLogsCleanupBatchSize = 10000 +const usageBillingDedupCleanupBatchSize = 10000 + // NewDashboardAggregationRepository 创建仪表盘预聚合仓储。 func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository { if sqlDB == nil { @@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool { } func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end 
time.Time) error { + if r == nil || r.sql == nil { + return nil + } loc := timezone.Location() startLocal := start.In(loc) endLocal := end.In(loc) @@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta dayEnd = dayEnd.Add(24 * time.Hour) } + if db, ok := r.sql.(*sql.DB); ok { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + txRepo := newDashboardAggregationRepositoryWithSQL(tx) + if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() + } + return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd) +} + +func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error { // 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。 if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { return err @@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c if isPartitioned { return r.dropUsageLogsPartitions(ctx, cutoff) } - _, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC()) - return err + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid + FROM usage_logs + WHERE created_at < $1 + LIMIT $2 + ) + DELETE FROM usage_logs + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageLogsCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageLogsCleanupBatchSize { + return nil + } + } +} + +func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + for { + res, err := r.sql.ExecContext(ctx, ` + WITH victims AS ( + SELECT ctid, request_id, api_key_id, request_fingerprint, created_at + FROM usage_billing_dedup + WHERE created_at < $1 + LIMIT $2 + ), archived AS ( + INSERT INTO 
usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at) + SELECT request_id, api_key_id, request_fingerprint, created_at + FROM victims + ON CONFLICT (request_id, api_key_id) DO NOTHING + ) + DELETE FROM usage_billing_dedup + WHERE ctid IN (SELECT ctid FROM victims) + `, cutoff.UTC(), usageBillingDedupCleanupBatchSize) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected < usageBillingDedupCleanupBatchSize { + return nil + } + } } func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 72422d18..dd3019bb 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) + // usage_billing_dedup: billing idempotency narrow table + var usageBillingDedupRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass)) + require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist") + requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key") + requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin") + + var usageBillingDedupArchiveRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT 
to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass)) + require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist") + requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false) + requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey") + // settings table should exist var settingsRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) @@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false) } +func requireIndex(t *testing.T, tx *sql.Tx, table, index string) { + t.Helper() + + var exists bool + err := tx.QueryRowContext(context.Background(), ` +SELECT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = $1 + AND indexname = $2 +) +`, table, index).Scan(&exists) + require.NoError(t, err, "query pg_indexes for %s.%s", table, index) + require.True(t, exists, "expected index %s on %s", index, table) +} + func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { t.Helper() diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go new file mode 100644 index 00000000..b13cfeb8 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo.go @@ -0,0 +1,308 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type usageBillingRepository struct { + db *sql.DB +} + +func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository { + 
return &usageBillingRepository{db: sqlDB} +} + +func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) { + if cmd == nil { + return &service.UsageBillingApplyResult{}, nil + } + if r == nil || r.db == nil { + return nil, errors.New("usage billing repository db is nil") + } + + cmd.Normalize() + if cmd.RequestID == "" { + return nil, service.ErrUsageBillingRequestIDRequired + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer func() { + if tx != nil { + _ = tx.Rollback() + } + }() + + applied, err := r.claimUsageBillingKey(ctx, tx, cmd) + if err != nil { + return nil, err + } + if !applied { + return &service.UsageBillingApplyResult{Applied: false}, nil + } + + result := &service.UsageBillingApplyResult{Applied: true} + if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + tx = nil + return result, nil +} + +func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) { + var id int64 + err := tx.QueryRowContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint) + VALUES ($1, $2, $3) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id + `, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + var existingFingerprint string + if err := tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil { + return false, err + } + if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if err != nil { + return false, err + } + var 
archivedFingerprint string + err = tx.QueryRowContext(ctx, ` + SELECT request_fingerprint + FROM usage_billing_dedup_archive + WHERE request_id = $1 AND api_key_id = $2 + `, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint) + if err == nil { + if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) { + return false, service.ErrUsageBillingRequestConflict + } + return false, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return false, err + } + return true, nil +} + +func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error { + if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil { + if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil { + return err + } + } + + if cmd.BalanceCost > 0 { + if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil { + return err + } + } + + if cmd.APIKeyQuotaCost > 0 { + exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost) + if err != nil { + return err + } + result.APIKeyQuotaExhausted = exhausted + } + + if cmd.APIKeyRateLimitCost > 0 { + if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil { + return err + } + } + + if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) { + if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil { + return err + } + } + + return nil +} + +func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error { + const updateSQL = ` + UPDATE user_subscriptions us + SET + daily_usage_usd = us.daily_usage_usd + $1, + weekly_usage_usd = us.weekly_usage_usd + $1, + monthly_usage_usd = us.monthly_usage_usd + $1, + updated_at = NOW() + FROM groups g 
+ WHERE us.id = $2 + AND us.deleted_at IS NULL + AND us.group_id = g.id + AND g.deleted_at IS NULL + ` + res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrSubscriptionNotFound +} + +func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE users + SET balance = balance - $1, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, amount, userID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected > 0 { + return nil + } + return service.ErrUserNotFound +} + +func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) { + var exhausted bool + err := tx.QueryRowContext(ctx, ` + UPDATE api_keys + SET quota_used = quota_used + $1, + status = CASE + WHEN quota > 0 + AND status = $3 + AND quota_used < quota + AND quota_used + $1 >= quota + THEN $4 + ELSE status + END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota + `, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted) + if errors.Is(err, sql.ErrNoRows) { + return false, service.ErrAPIKeyNotFound + } + if err != nil { + return false, err + } + return exhausted, nil +} + +func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error { + res, err := tx.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE 
usage_1d + $1 END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END, + window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END, + window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END, + updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + `, cost, apiKeyID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrAPIKeyNotFound + } + return nil +} + +func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error { + rows, err := tx.QueryContext(ctx, + `UPDATE accounts SET extra = ( + COALESCE(extra, '{}'::jsonb) + || jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1) + || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_daily_used', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, + 'quota_daily_start', + CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) + + '24 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN + jsonb_build_object( + 'quota_weekly_used', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval 
<= NOW() + THEN $1 + ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, + 'quota_weekly_start', + CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) + + '168 hours'::interval <= NOW() + THEN `+nowUTC+` + ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END + ) + ELSE '{}'::jsonb END + ), updated_at = NOW() + WHERE id = $2 AND deleted_at IS NULL + RETURNING + COALESCE((extra->>'quota_used')::numeric, 0), + COALESCE((extra->>'quota_limit')::numeric, 0)`, + amount, accountID) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + var newUsed, limit float64 + if rows.Next() { + if err := rows.Scan(&newUsed, &limit); err != nil { + return err + } + } else { + if err := rows.Err(); err != nil { + return err + } + return service.ErrAccountNotFound + } + if err := rows.Err(); err != nil { + return err + } + if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { + logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) + return err + } + } + return nil +} diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go new file mode 100644 index 00000000..eda34cc9 --- /dev/null +++ b/backend/internal/repository/usage_billing_repo_integration_test.go @@ -0,0 +1,279 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user 
:= mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-" + uuid.NewString(), + Name: "billing", + Quota: 1, + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-account-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + BalanceCost: 1.25, + APIKeyQuotaCost: 1.25, + APIKeyRateLimitCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result1) + require.True(t, result1.Applied) + require.True(t, result1.APIKeyQuotaExhausted) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.NotNil(t, result2) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) + + var quotaUsed float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan(&quotaUsed)) + require.InDelta(t, 1.25, quotaUsed, 0.000001) + + var usage5h float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h)) + require.InDelta(t, 1.25, usage5h, 0.000001) + + var status string + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status)) + require.Equal(t, service.StatusAPIKeyQuotaExhausted, status) + + var dedupCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM
usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount)) + require.Equal(t, 1, dedupCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + group := mustCreateGroup(t, client, &service.Group{ + Name: "usage-billing-group-" + uuid.NewString(), + Platform: service.PlatformAnthropic, + SubscriptionType: service.SubscriptionTypeSubscription, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + GroupID: &group.ID, + Key: "sk-usage-billing-sub-" + uuid.NewString(), + Name: "billing-sub", + }) + subscription := mustCreateSubscription(t, client, &service.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: 0, + SubscriptionID: &subscription.ID, + SubscriptionCost: 2.5, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var dailyUsage float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage)) + require.InDelta(t, 2.5, dailyUsage, 0.000001) +} + +func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: 
fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-conflict-" + uuid.NewString(), + Name: "billing-conflict", + }) + + requestID := uuid.NewString() + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + }) + require.NoError(t, err) + + _, err = repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 2.50, + }) + require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict) +} + +func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-account-" + uuid.NewString(), + Name: "billing-account", + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-account-quota-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + Extra: map[string]any{ + "quota_limit": 100.0, + }, + }) + + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKey.ID, + UserID: user.ID, + AccountID: account.ID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 3.5, + }) + require.NoError(t, err) + + var quotaUsed float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan(&quotaUsed)) + require.InDelta(t, 3.5, quotaUsed, 0.000001) +} + +func
TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) { + ctx := context.Background() + repo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + oldRequestID := "dedup-old-" + uuid.NewString() + newRequestID := "dedup-new-" + uuid.NewString() + oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400) + newCreatedAt := time.Now().UTC().Add(-time.Hour) + + _, err := integrationDB.ExecContext(ctx, ` + INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at) + VALUES ($1, 1, $2, $3), ($4, 1, $5, $6) + `, + oldRequestID, strings.Repeat("a", 64), oldCreatedAt, + newRequestID, strings.Repeat("b", 64), newCreatedAt, + ) + require.NoError(t, err) + + require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + var oldCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount)) + require.Equal(t, 0, oldCount) + + var newCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount)) + require.Equal(t, 1, newCount) + + var archivedCount int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount)) + require.Equal(t, 1, archivedCount) +} + +func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB) + + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Balance: 100, + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: 
user.ID, + Key: "sk-usage-billing-archive-" + uuid.NewString(), + Name: "billing-archive", + }) + + requestID := uuid.NewString() + cmd := &service.UsageBillingCommand{ + RequestID: requestID, + APIKeyID: apiKey.ID, + UserID: user.ID, + BalanceCost: 1.25, + } + + result1, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.True(t, result1.Applied) + + _, err = integrationDB.ExecContext(ctx, ` + UPDATE usage_billing_dedup + SET created_at = $1 + WHERE request_id = $2 AND api_key_id = $3 + `, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID) + require.NoError(t, err) + require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365))) + + result2, err := repo.Apply(ctx, cmd) + require.NoError(t, err) + require.False(t, result2.Applied) + + var balance float64 + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance)) + require.InDelta(t, 98.75, balance, 0.000001) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 8ffcb2f3..5e81818b 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,12 +3,14 @@ package repository import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "os" "strconv" "strings" "sync" + "sync/atomic" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -17,11 +19,13 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbuser "github.com/Wei-Shaw/sub2api/ent/user" dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + gocache "github.com/patrickmn/go-cache" ) const usageLogSelectColumns = "id, user_id, api_key_id, 
account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" @@ -47,18 +51,29 @@ type usageLogRepository struct { sql sqlExecutor db *sql.DB - createBatchOnce sync.Once - createBatchCh chan usageLogCreateRequest + createBatchOnce sync.Once + createBatchCh chan usageLogCreateRequest + bestEffortBatchOnce sync.Once + bestEffortBatchCh chan usageLogBestEffortRequest + bestEffortRecent *gocache.Cache } const ( usageLogCreateBatchMaxSize = 64 usageLogCreateBatchWindow = 3 * time.Millisecond usageLogCreateBatchQueueCap = 4096 + usageLogCreateCancelWait = 2 * time.Second + + usageLogBestEffortBatchMaxSize = 256 + usageLogBestEffortBatchWindow = 20 * time.Millisecond + usageLogBestEffortBatchQueueCap = 32768 + usageLogBestEffortRecentTTL = 30 * time.Second ) type usageLogCreateRequest struct { log *service.UsageLog + prepared usageLogInsertPrepared + shared *usageLogCreateShared resultCh chan usageLogCreateResult } @@ -67,6 +82,12 @@ type usageLogCreateResult struct { err error } +type usageLogBestEffortRequest struct { + prepared usageLogInsertPrepared + apiKeyID int64 + resultCh chan error +} + type usageLogInsertPrepared struct { createdAt time.Time requestID string @@ -80,6 +101,25 @@ type usageLogBatchState struct { CreatedAt time.Time } +type usageLogBatchRow struct { + RequestID string `json:"request_id"` + APIKeyID int64 `json:"api_key_id"` + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Inserted bool `json:"inserted"` +} + +type usageLogCreateShared struct { + state atomic.Int32 +} + +const ( + 
usageLogCreateStateQueued int32 = iota + usageLogCreateStateProcessing + usageLogCreateStateCompleted + usageLogCreateStateCanceled +) + func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository { return newUsageLogRepositoryWithSQL(client, sqlDB) } @@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage if db, ok := sqlq.(*sql.DB); ok { repo.db = db } + repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute) return repo } @@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) if tx := dbent.TxFromContext(ctx); tx != nil { return r.createSingle(ctx, tx.Client(), log) } - if r.db == nil { - return r.createSingle(ctx, r.sql, log) - } requestID := strings.TrimSpace(log.RequestID) if requestID == "" { return r.createSingle(ctx, r.sql, log) @@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) return r.createBatched(ctx, log) } +func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error { + if log == nil { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + _, err := r.createSingle(ctx, tx.Client(), log) + return err + } + if r.db == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + r.ensureBestEffortBatcher() + if r.bestEffortBatchCh == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + req := usageLogBestEffortRequest{ + prepared: prepareUsageLogInsert(log), + apiKeyID: log.APIKeyID, + resultCh: make(chan error, 1), + } + if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok { + if _, exists := r.bestEffortRecent.Get(key); exists { + return nil + } + } + + select { + case r.bestEffortBatchCh <- req: + case <-ctx.Done(): + return ctx.Err() + default: + return errors.New("usage log best-effort queue full") + } + + select { + case err := <-req.resultCh: + return 
err + case <-ctx.Done(): + return ctx.Err() + } +} + func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { prepared := prepareUsageLogInsert(log) if sqlq == nil { sqlq = r.sql } + if ctx != nil && ctx.Err() != nil { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } query := ` INSERT INTO usage_logs ( @@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa req := usageLogCreateRequest{ log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, resultCh: make(chan usageLogCreateResult, 1), } select { case r.createBatchCh <- req: case <-ctx.Done(): - return false, ctx.Err() + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) default: return r.createSingle(ctx, r.sql, log) } @@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa case res := <-req.resultCh: return res.inserted, res.err case <-ctx.Done(): - return false, ctx.Err() + if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } + timer := time.NewTimer(usageLogCreateCancelWait) + defer timer.Stop() + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-timer.C: + return false, ctx.Err() + } } } @@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() { }) } +func (r *usageLogRepository) ensureBestEffortBatcher() { + if r == nil || r.db == nil { + return + } + r.bestEffortBatchOnce.Do(func() { + r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap) + go r.runBestEffortBatcher(r.db) + }) +} + func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { for { first, ok := <-r.createBatchCh @@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { } } +func (r 
*usageLogRepository) runBestEffortBatcher(db *sql.DB) { + for { + first, ok := <-r.bestEffortBatchCh + if !ok { + return + } + + batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogBestEffortBatchWindow) + bestEffortLoop: + for len(batch) < usageLogBestEffortBatchMaxSize { + select { + case req, ok := <-r.bestEffortBatchCh: + if !ok { + break bestEffortLoop + } + batch = append(batch, req) + case <-timer.C: + break bestEffortLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushBestEffortBatch(db, batch) + } +} + func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { if len(batch) == 0 { return @@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate for _, req := range batch { if req.log == nil { - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil}) + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) continue } - prepared := prepareUsageLogInsert(req.log) + if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) { + if req.shared.state.Load() == usageLogCreateStateCanceled { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: service.MarkUsageLogCreateNotPersisted(context.Canceled), + }) + continue + } + } + prepared := req.prepared if prepared.requestID == "" { fallback = append(fallback, req) continue @@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate } if len(uniqueOrder) > 0 { - insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) if err != nil { - for _, key := range uniqueOrder { - fallback = append(fallback, 
requestsByKey[key]...) + if safeFallback { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, hasState := stateMap[key] + inserted := insertedMap[key] + for idx, req := range reqs { + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + if hasState { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + } + switch { + case inserted && idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil}) + case inserted: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case hasState: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err}) + default: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + } + } + } } } else { for _, key := range uniqueOrder { @@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate state, ok := stateMap[key] if !ok { for _, req := range reqs { - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + completeUsageLogCreateRequest(req, usageLogCreateResult{ inserted: false, err: fmt.Errorf("usage log batch state missing for key=%s", key), }) @@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate req.log.ID = state.ID req.log.CreatedAt = state.CreatedAt req.log.RateMultiplier = preparedByKey[key].rateMultiplier - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{ + completeUsageLogCreateRequest(req, usageLogCreateResult{ inserted: idx == 0 && insertedMap[key], err: nil, }) @@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate return } - fallbackCtx, cancel := context.WithTimeout(context.Background(), 
10*time.Second) - defer cancel() for _, req := range fallback { + fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) inserted, err := r.createSingle(fallbackCtx, db, req.log) - sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err}) + cancel() + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err}) } } -func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) { +func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) { + if len(batch) == 0 { + return + } + + type bestEffortGroup struct { + prepared usageLogInsertPrepared + apiKeyID int64 + key string + reqs []usageLogBestEffortRequest + } + + groupsByKey := make(map[string]*bestEffortGroup, len(batch)) + groupOrder := make([]*bestEffortGroup, 0, len(batch)) + preparedList := make([]usageLogInsertPrepared, 0, len(batch)) + + for idx, req := range batch { + prepared := req.prepared + key := fmt.Sprintf("__best_effort_%d", idx) + if prepared.requestID != "" { + key = usageLogBatchKey(prepared.requestID, req.apiKeyID) + } + group, exists := groupsByKey[key] + if !exists { + group = &bestEffortGroup{ + prepared: prepared, + apiKeyID: req.apiKeyID, + key: key, + } + groupsByKey[key] = group + groupOrder = append(groupOrder, group) + preparedList = append(preparedList, prepared) + } + group.reqs = append(group.reqs, req) + } + + if len(preparedList) == 0 { + for _, req := range batch { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBestEffortInsertQuery(preparedList) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err) 
+ for _, group := range groupOrder { + singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared) + if singleErr != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr) + } else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, singleErr) + } + } + return + } + for _, group := range groupOrder { + if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + } +} + +func sendUsageLogBestEffortResult(ch chan error, err error) { + if ch == nil { + return + } + select { + case ch <- err: + default: + } +} + +func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) { + if req.shared != nil { + req.shared.state.Store(usageLogCreateStateCompleted) + } + sendUsageLogCreateResult(req.resultCh, res) +} + +func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) { if len(keys) == 0 { - return map[string]bool{}, map[string]usageLogBatchState{}, nil + return map[string]bool{}, map[string]usageLogBatchState{}, false, nil } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) - rows, err := db.QueryContext(ctx, query, args...) 
- if err != nil { - return nil, nil, err + var payload []byte + if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil { + return nil, nil, true, err + } + var rows []usageLogBatchRow + if err := json.Unmarshal(payload, &rows); err != nil { + return nil, nil, false, err } insertedMap := make(map[string]bool, len(keys)) - for rows.Next() { - var ( - requestID string - apiKeyID int64 - id int64 - createdAt time.Time - ) - if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { - _ = rows.Close() - return nil, nil, err + stateMap := make(map[string]usageLogBatchState, len(keys)) + for _, row := range rows { + key := usageLogBatchKey(row.RequestID, row.APIKeyID) + insertedMap[key] = row.Inserted + stateMap[key] = usageLogBatchState{ + ID: row.ID, + CreatedAt: row.CreatedAt, } - insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true } - if err := rows.Err(); err != nil { - _ = rows.Close() - return nil, nil, err + if len(stateMap) != len(keys) { + return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys)) } - _ = rows.Close() - - stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey) - if err != nil { - return nil, nil, err - } - return insertedMap, stateMap, nil + return insertedMap, stateMap, false, nil } func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { var query strings.Builder _, _ = query.WriteString(` + WITH input ( + input_idx, + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, 
+ first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(keys)*37) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + args = append(args, idx) + argPos++ + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + _, _ = query.WriteString(",") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) + } + _, _ = query.WriteString(` + ), + inserted AS ( + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + 
service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO UPDATE + SET request_id = usage_logs.request_id + RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted + ) + SELECT COALESCE( + json_agg( + json_build_object( + 'request_id', inserted.request_id, + 'api_key_id', inserted.api_key_id, + 'id', inserted.id, + 'created_at', inserted.created_at, + 'inserted', inserted.inserted + ) + ORDER BY input.input_idx + ), + '[]'::json + ) + FROM input + JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + `) + return query.String(), args +} + +func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(preparedList)*36) + argPos := 1 + for idx, prepared := range preparedList { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + + _, _ = query.WriteString(` + ) INSERT INTO usage_logs ( user_id, api_key_id, @@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage reasoning_effort, cache_ttl_overridden, created_at - ) VALUES `) - - args := make([]any, 0, len(keys)*36) - argPos := 1 - for idx, key := range keys { - if idx > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("(") - prepared := preparedByKey[key] - for i := 0; i < len(prepared.args); i++ { - if i > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("$") - _, _ = query.WriteString(strconv.Itoa(argPos)) - argPos++ - } - _, _ = query.WriteString(")") - args = append(args, prepared.args...) - } - _, _ = query.WriteString(` + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING request_id, api_key_id, id, created_at `) + return query.String(), args } -func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) { - var query strings.Builder - _, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `) - args := make([]any, 0, len(keys)*2) - argPos := 1 - for idx, key := range keys { - if idx > 0 { - _, _ = query.WriteString(" OR ") - } - prepared := preparedByKey[key] - apiKeyID := 
prepared.args[1] - _, _ = query.WriteString("(request_id = $") - _, _ = query.WriteString(strconv.Itoa(argPos)) - _, _ = query.WriteString(" AND api_key_id = $") - _, _ = query.WriteString(strconv.Itoa(argPos + 1)) - _, _ = query.WriteString(")") - args = append(args, prepared.requestID, apiKeyID) - argPos += 2 - } - - rows, err := db.QueryContext(ctx, query.String(), args...) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - stateMap := make(map[string]usageLogBatchState, len(keys)) - for rows.Next() { - var ( - requestID string - apiKeyID int64 - id int64 - createdAt time.Time +func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error { + _, err := sqlq.ExecContext(ctx, ` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + media_type, + service_tier, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 ) - if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil { - return nil, err - } - stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{ - ID: id, - CreatedAt: createdAt, - } - } - if err := rows.Err(); err != nil { - return nil, err - } - return stateMap, nil + ON CONFLICT (request_id, api_key_id) DO NOTHING + `, prepared.args...) 
+	return err
 }

 func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
@@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe
 	}
 }

+func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
+	requestID = strings.TrimSpace(requestID)
+	if requestID == "" || r == nil || r.bestEffortRecent == nil {
+		return "", false
+	}
+	return usageLogBatchKey(requestID, apiKeyID), true
+}
+
 func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
 	query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
 	rows, err := r.sql.QueryContext(ctx, query, id)
diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go
index d2e1e9d4..00740878 100644
--- a/backend/internal/repository/usage_log_repo_integration_test.go
+++ b/backend/internal/repository/usage_log_repo_integration_test.go
@@ -183,6 +183,214 @@ func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
 	require.Equal(t, 1, count)
 }

+func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
+	ctx := context.Background()
+	client := testEntClient(t)
+	repo := newUsageLogRepositoryWithSQL(client, integrationDB)
+
+	user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
+	apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
+	account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
+	requestID := uuid.NewString()
+
+	const total = 8
+	batch := make([]usageLogCreateRequest, 0, total)
+	logs := make([]*service.UsageLog, 0, total)
+
+	for i := 0; i < total; i++ {
+		log := &service.UsageLog{
+			UserID:       user.ID,
+			APIKeyID:     apiKey.ID,
+ AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10 + i, + OutputTokens: 20 + i, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + logs = append(logs, log) + batch = append(batch, usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + resultCh: make(chan usageLogCreateResult, 1), + }) + } + + repo.flushCreateBatch(integrationDB, batch) + + insertedCount := 0 + var firstID int64 + for idx, req := range batch { + res := <-req.resultCh + require.NoError(t, res.err) + if res.inserted { + insertedCount++ + } + require.NotZero(t, logs[idx].ID) + if idx == 0 { + firstID = logs[idx].ID + } else { + require.Equal(t, firstID, logs[idx].ID) + } + } + + require.Equal(t, 1, insertedCount) + + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)) + require.Equal(t, 1, count) +} + +func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()}) + requestID := uuid.NewString() + + log1 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + log2 := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: requestID, + 
Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + + require.NoError(t, repo.CreateBestEffort(ctx, log1)) + require.NoError(t, repo.CreateBestEffort(ctx, log2)) + + require.Eventually(t, func() bool { + var count int + err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count) + return err == nil && count == 1 + }, 3*time.Second, 20*time.Millisecond) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + +func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, 
&service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()}) + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + + go func() { + _, err := repo.createBatched(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + errCh <- err + }() + + req := <-repo.createBatchCh + require.NotNil(t, req.shared) + cancel() + + err := <-errCh + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)}) +} + +func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) { + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + } + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + req.shared.state.Store(usageLogCreateStateCanceled) + 
+	repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
+
+	res := <-req.resultCh
+	require.False(t, res.inserted)
+	require.Error(t, res.err)
+	require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
+}
+
 func (s *UsageLogRepoSuite) TestGetByID() {
 	user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
 	apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 5fe7a98e..01395bcb 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
 	NewAnnouncementRepository,
 	NewAnnouncementReadRepository,
 	NewUsageLogRepository,
+	NewUsageBillingRepository,
 	NewIdempotencyRepository,
 	NewUsageCleanupRepository,
 	NewDashboardAggregationRepository,
diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go
index a67f8532..b58a1ea9 100644
--- a/backend/internal/service/dashboard_aggregation_service.go
+++ b/backend/internal/service/dashboard_aggregation_service.go
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
 	UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
 	CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
 	CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
+	CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
 	EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
 }

@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
 	hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
 	dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
 	usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
+	dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)

 	aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
 	if aggErr != nil {
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
 	if usageErr != nil {
 		logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
 	}
-	if aggErr == nil && usageErr == nil {
+	dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
+	if dedupErr != nil {
+		logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
+	}
+	if aggErr == nil && usageErr == nil && dedupErr == nil {
 		s.lastRetentionCleanup.Store(now)
 	}
 }
diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go
index a7058985..fbb671bb 100644
--- a/backend/internal/service/dashboard_aggregation_service_test.go
+++ b/backend/internal/service/dashboard_aggregation_service_test.go
@@ -12,12 +12,18 @@ import (

 type dashboardAggregationRepoTestStub struct {
 	aggregateCalls       int
+	recomputeCalls       int
+	cleanupUsageCalls    int
+	cleanupDedupCalls    int
+	ensurePartitionCalls int
 	lastStart            time.Time
 	lastEnd              time.Time
 	watermark            time.Time
 	aggregateErr         error
 	cleanupAggregatesErr error
 	cleanupUsageErr      error
+	cleanupDedupErr      error
+	ensurePartitionErr   error
 }

 func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
 }

 func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
+	s.recomputeCalls++
 	return s.AggregateRange(ctx, start, end)
 }

@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
 }

 func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
+	s.cleanupUsageCalls++
 	return s.cleanupUsageErr
 }
+func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + s.cleanupDedupCalls++ + return s.cleanupDedupErr +} + func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { - return nil + s.ensurePartitionCalls++ + return s.ensurePartitionErr } func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) { @@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupUsageCalls) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) + + require.Nil(t, svc.lastRetentionCleanup.Load()) + require.Equal(t, 1, repo.cleanupDedupCalls) +} + +func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + UsageBillingDedupDays: 2, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.runScheduledAggregation() + + require.Equal(t, 1, repo.ensurePartitionCalls) + require.Equal(t, 1, repo.aggregateCalls) } func 
TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index 59b83e66..2a7f47b6 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut return nil } +func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 5dcda1de..789cbab8 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, - deferredService: &DeferredService{}, - billingCacheService: nil, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, } account := &Account{ @@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo }, } - svc := &GatewayService{ - cfg: &config.Config{ - Gateway: config.GatewayConfig{ - MaxLineSize: defaultMaxLineSize, - }, + cfg := 
&config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, }, - httpUpstream: upstream, - rateLimitService: &RateLimitService{}, + } + svc := &GatewayService{ + cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, } account := &Account{ @@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf require.Equal(t, 5, result.usage.OutputTokens) } +func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") + require.NotNil(t, result) +} + func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi _ = pr.Close() <-done - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "stream usage incomplete after timeout") 
 	require.NotNil(t, result)
 	require.True(t, result.clientDisconnect)
 	require.Equal(t, 9, result.usage.InputTokens)
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
 	}

 	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
-	require.NoError(t, err)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "stream usage incomplete")
 	require.NotNil(t, result)
 	require.True(t, result.clientDisconnect)
 }
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
 	}

 	result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
-	require.NoError(t, err)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
 	require.NotNil(t, result)
 	require.True(t, result.clientDisconnect)
 	require.Equal(t, 8, result.usage.InputTokens)
diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go
new file mode 100644
index 00000000..92e59ac8
--- /dev/null
+++ b/backend/internal/service/gateway_record_usage_test.go
@@ -0,0 +1,261 @@
+//go:build unit
+
+package service
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+	"github.com/stretchr/testify/require"
+)
+
+func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
+	cfg := &config.Config{}
+	cfg.Default.RateMultiplier = 1.1
+	return NewGatewayService(
+		nil,
+		nil,
+		usageRepo,
+		nil,
+		userRepo,
+		subRepo,
+		nil,
+		nil,
+		cfg,
+		nil,
+		nil,
+		NewBillingService(cfg, nil),
+		nil,
+		&BillingCacheService{},
+		nil,
+		nil,
+		&DeferredService{},
+		nil,
+		nil,
+ nil, + nil, + nil, + ) +} + +func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService { + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.usageBillingRepo = billingRepo + return svc +} + +func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 501, + Quota: 100, + }, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`)) + err := svc.RecordUsage(context.Background(), 
&RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_hash", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_payload_fallback", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash) +} + +func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: 
"gateway_not_persisted", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 503, + Quota: 100, + }, + User: &User{ID: 603}, + Account: &Account{ID: 703}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{ + Result: &ForwardResult{ + RequestID: "gateway_long_context_detached_ctx", + Usage: ClaudeUsage{ + InputTokens: 12, + OutputTokens: 8, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 502, + Quota: 100, + }, + User: &User{ID: 602}, + Account: &Account{ID: 702}, + LongContextThreshold: 200000, + LongContextMultiplier: 2, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + ctx := context.WithValue(context.Background(), 
ctxkey.RequestID, "gateway-local-fallback") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 504}, + User: &User{ID: 604}, + Account: &Account{ID: 704}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_billing_fail", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 505}, + User: &User{ID: 605}, + Account: &Account{ID: 705}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 080de063..670ff21e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -50,6 +50,7 @@ const ( defaultUserGroupRateCacheTTL = 30 * time.Second defaultModelsListCacheTTL = 15 * time.Second + postUsageBillingTimeout = 15 * time.Second ) const ( @@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() } +func claudeUsageHasAnyTokens(usage 
*ClaudeUsage) bool { + return usage != nil && (usage.InputTokens > 0 || + usage.OutputTokens > 0 || + usage.CacheCreationInputTokens > 0 || + usage.CacheReadInputTokens > 0 || + usage.CacheCreation5mTokens > 0 || + usage.CacheCreation1hTokens > 0) +} + +func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool { + return usage != nil && (usage.InputTokens > 0 || + usage.OutputTokens > 0 || + usage.CacheCreationInputTokens > 0 || + usage.CacheReadInputTokens > 0) +} + +func openAIStreamEventIsTerminal(data string) bool { + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + switch gjson.Get(trimmed, "type").String() { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func anthropicStreamEventIsTerminal(eventName, data string) bool { + if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") { + return true + } + trimmed := strings.TrimSpace(data) + if trimmed == "" { + return false + } + if trimmed == "[DONE]" { + return true + } + return gjson.Get(trimmed, "type").String() == "message_stop" +} + func cloneStringSlice(src []string) []string { if len(src) == 0 { return nil @@ -504,6 +551,7 @@ type GatewayService struct { accountRepo AccountRepository groupRepo GroupRepository usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository userRepo UserRepository userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository @@ -537,6 +585,7 @@ func NewGatewayService( accountRepo AccountRepository, groupRepo GroupRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -563,6 +612,7 @@ func NewGatewayService( accountRepo: accountRepo, groupRepo: groupRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: 
userSubRepo, userGroupRateRepo: userGroupRateRepo, @@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx() if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + 
retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream) + retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx2() if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A rectifiedBody, applied := RectifyThinkingBudget(body) if applied && time.Since(retryStart) < maxRetryElapsed { logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) - budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseBudgetRetryCtx() if buildErr == nil { budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return 
nil, err } @@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( usage := &ClaudeUsage{} var firstTokenMs *int clientDisconnected := false + sawTerminalEvent := false scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 flusher.Flush() } + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } if errors.Is(ev.err, bufio.ErrTooLong) { logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d 
error=%v", account.ID, maxLineSize, ev.err) @@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( line := ev.line if data, ok := extractAnthropicSSEDataLine(line); ok { trimmed := strings.TrimSpace(data) + if anthropicStreamEventIsTerminal("", trimmed) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsagePassthrough(data, usage) + } else { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { + sawTerminalEvent = true + } } if !clientDisconnected { @@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( continue } if clientDisconnected { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) if s.rateLimitService != nil { @@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + sawTerminalEvent := false pendingEventLines := make([]string, 0, 4) @@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if dataLine == "[DONE]" { + sawTerminalEvent = true block := "" if eventName != "" { block = "event: " + eventName + "\n" @@ -6107,6 +6180,9 @@ 
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } usagePatch := s.extractSSEUsagePatch(event) + if anthropicStreamEventIsTerminal(eventName, dataLine) { + sawTerminalEvent = true + } if !eventChanged { block := "" if eventName != "" { @@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http case ev, ok := <-events: if !ok { // 上游完成,返回结果 + if !sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { @@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http continue } if 
clientDisconnected { - // 客户端已断开,上游也超时了,返回已收集的 usage - logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface { UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } +type apiKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) +} + +type usageLogBestEffortWriter interface { + CreateBestEffort(ctx context.Context, log *UsageLog) error +} + // postUsageBillingParams 统一扣费所需的参数 type postUsageBillingParams struct { Cost 
*CostBreakdown @@ -6581,6 +6668,7 @@ type postUsageBillingParams struct { APIKey *APIKey Account *Account Subscription *UserSubscription + RequestPayloadHash string IsSubscriptionBill bool AccountRateMultiplier float64 APIKeyService APIKeyQuotaUpdater @@ -6592,19 +6680,22 @@ type postUsageBillingParams struct { // - API Key 限速用量更新 // - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + cost := p.Cost // 1. 订阅 / 余额扣费 if p.IsSubscriptionBill { if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) } } else { if cost.ActualCost > 0 { - if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { + if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) } deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) @@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill // 2. API Key 配额 if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { - if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) } } // 3. 
API Key 限速用量 if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { - if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { + if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) } - deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost) } // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { accountCost := cost.TotalCost * p.AccountRateMultiplier - if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { + if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) } } - // 5. 更新账号最近使用时间 + finalizePostUsageBilling(p, deps) +} + +func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { + if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" { + return payloadHash + } + if ctx != nil { + if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { + return "client:" + strings.TrimSpace(clientRequestID) + } + if 
requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { + return "local:" + strings.TrimSpace(requestID) + } + } + return "" +} + +func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand { + if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil { + return nil + } + + cmd := &UsageBillingCommand{ + RequestID: requestID, + APIKeyID: p.APIKey.ID, + UserID: p.User.ID, + AccountID: p.Account.ID, + AccountType: p.Account.Type, + RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash), + } + if usageLog != nil { + cmd.Model = usageLog.Model + cmd.BillingType = usageLog.BillingType + cmd.InputTokens = usageLog.InputTokens + cmd.OutputTokens = usageLog.OutputTokens + cmd.CacheCreationTokens = usageLog.CacheCreationTokens + cmd.CacheReadTokens = usageLog.CacheReadTokens + cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } + if usageLog.ServiceTier != nil { + cmd.ServiceTier = *usageLog.ServiceTier + } + if usageLog.ReasoningEffort != nil { + cmd.ReasoningEffort = *usageLog.ReasoningEffort + } + if usageLog.SubscriptionID != nil { + cmd.SubscriptionID = usageLog.SubscriptionID + } + } + + if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { + cmd.SubscriptionID = &p.Subscription.ID + cmd.SubscriptionCost = p.Cost.TotalCost + } else if p.Cost.ActualCost > 0 { + cmd.BalanceCost = p.Cost.ActualCost + } + + if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { + cmd.APIKeyQuotaCost = p.Cost.ActualCost + } + if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { + cmd.APIKeyRateLimitCost = p.Cost.ActualCost + } + if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { + cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier + } + + cmd.Normalize() + return cmd +} + +func 
applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) { + if p == nil || deps == nil { + return false, nil + } + + cmd := buildUsageBillingCommand(requestID, usageLog, p) + if cmd == nil || cmd.RequestID == "" || repo == nil { + postUsageBilling(ctx, p, deps) + return true, nil + } + + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + result, err := repo.Apply(billingCtx, cmd) + if err != nil { + return false, err + } + + if result == nil || !result.Applied { + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) + return false, nil + } + + if result.APIKeyQuotaExhausted { + if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" { + invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key) + } + } + + finalizePostUsageBilling(p, deps) + return true, nil +} + +func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { + if p == nil || p.Cost == nil || deps == nil { + return + } + + if p.IsSubscriptionBill { + if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + } + } else if p.Cost.ActualCost > 0 && p.User != nil { + deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) + } + + if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() { + deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost) + } + deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) } +func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, postUsageBillingTimeout) +} + +func detachStreamUpstreamContext(ctx 
context.Context, stream bool) (context.Context, context.CancelFunc) { + if !stream { + return ctx, func() {} + } + if ctx == nil { + return context.Background(), func() {} + } + return context.WithoutCancel(ctx), func() {} +} + // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) type billingDeps struct { accountRepo AccountRepository @@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps { } } +func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { + if repo == nil || usageLog == nil { + return + } + usageCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + if writer, ok := repo.(usageLogBestEffortWriter); ok { + if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { + logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) + } + } + return + } + + if _, err := repo.Create(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + } +} + // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result @@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu mediaType = &result.MediaType } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, 
usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } @@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct { APIKey *APIKey User *User Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService *APIKeyService // API Key 配额服务(可选) + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 
2.0) + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) - if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) - } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), 
s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") return nil } diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index cd690cbd..b1584827 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) { result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) _ = pr.Close() - require.NoError(t, err) + require.Error(t, err) + require.Contains(t, err.Error(), "missing terminal event") require.NotNil(t, result) } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 9529462e..f05fa5f5 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -7,35 +7,63 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/stretchr/testify/require" ) type openAIRecordUsageLogRepoStub struct { UsageLogRepository - inserted bool - err error - calls int - lastLog *UsageLog + inserted bool + err error + calls int + lastLog *UsageLog + lastCtxErr error } func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { s.calls++ s.lastLog = log + s.lastCtxErr = ctx.Err() return s.inserted, s.err } +type openAIRecordUsageBillingRepoStub struct { + UsageBillingRepository + + result *UsageBillingApplyResult + err error + calls int + lastCmd *UsageBillingCommand + lastCtxErr error +} + +func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) { + s.calls++ + 
s.lastCmd = cmd + s.lastCtxErr = ctx.Err() + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &UsageBillingApplyResult{Applied: true}, nil +} + type openAIRecordUsageUserRepoStub struct { UserRepository deductCalls int deductErr error lastAmount float64 + lastCtxErr error } func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { s.deductCalls++ s.lastAmount = amount + s.lastCtxErr = ctx.Err() return s.deductErr } @@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct { incrementCalls int incrementErr error + lastCtxErr error } func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { s.incrementCalls++ + s.lastCtxErr = ctx.Err() return s.incrementErr } type openAIRecordUsageAPIKeyQuotaStub struct { - quotaCalls int - rateLimitCalls int - err error - lastAmount float64 + quotaCalls int + rateLimitCalls int + err error + lastAmount float64 + lastQuotaCtxErr error + lastRateLimitCtxErr error } func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { s.quotaCalls++ s.lastAmount = cost + s.lastQuotaCtxErr = ctx.Err() return s.err } func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { s.rateLimitCalls++ s.lastAmount = cost + s.lastRateLimitCtxErr = ctx.Err() return s.err } @@ -93,23 +127,38 @@ func i64p(v int64) *int64 { func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { cfg := &config.Config{} cfg.Default.RateMultiplier = 1.1 + svc := NewOpenAIGatewayService( + nil, + usageRepo, + nil, + userRepo, + subRepo, + rateRepo, + nil, + cfg, + nil, + nil, + NewBillingService(cfg, nil), + nil, + &BillingCacheService{}, + nil, + &DeferredService{}, + nil, 
+ ) + svc.userGroupRateResolver = newUserGroupRateResolver( + rateRepo, + nil, + resolveUserGroupRateCacheTTL(cfg), + nil, + "service.openai_gateway.test", + ) + return svc +} - return &OpenAIGatewayService{ - usageLogRepo: usageRepo, - userRepo: userRepo, - userSubRepo: subRepo, - cfg: cfg, - billingService: NewBillingService(cfg, nil), - billingCacheService: &BillingCacheService{}, - deferredService: &DeferredService{}, - userGroupRateResolver: newUserGroupRateResolver( - rateRepo, - nil, - resolveUserGroupRateCacheTTL(cfg), - nil, - "service.openai_gateway.test", - ), - } +func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo) + svc.usageBillingRepo = billingRepo + return svc } func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { @@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} - svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ Result: &OpenAIForwardResult{ @@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin }) require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) 
require.Equal(t, 1, usageRepo.calls) require.Equal(t, 0, userRepo.deductCalls) require.Equal(t, 0, subRepo.incrementCalls) } +func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate_billing_key", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10045, + Quota: 100, + }, + User: &User{ID: 20045}, + Account: &Account{ID: 30045}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) { + usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_usage_log_error", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10041}, + User: &User{ID: 20041}, + Account: 
&Account{ID: 30041}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_not_persisted", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10043, + Quota: 100, + }, + User: &User{ID: 20043}, + Account: &Account{ID: 30043}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 1, quotaSvc.quotaCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) { + usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} + usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_ctx", + Usage: usage, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + 
ID: 10042, + Quota: 100, + }, + User: &User{ID: 20042}, + Account: &Account{ID: 30042}, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, userRepo.deductCalls) + require.NoError(t, userRepo.lastCtxErr) + require.Equal(t, 1, quotaSvc.quotaCalls) + require.NoError(t, quotaSvc.lastQuotaCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + reqCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_detached_billing_repo_ctx", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10046}, + User: &User{ID: 20046}, + Account: &Account{ID: 30046}, + }) + + require.NoError(t, err) + require.Equal(t, 1, billingRepo.calls) + require.NoError(t, billingRepo.lastCtxErr) + require.Equal(t, 1, usageRepo.calls) + require.NoError(t, usageRepo.lastCtxErr) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`)) + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: 
"openai_payload_hash", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "gpt-5", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 501, Quota: 100}, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + RequestPayloadHash: payloadHash, + }) + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10047}, + User: &User{ID: 20047}, + Account: &Account{ID: 30047}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := 
svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_fail", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10048}, + User: &User{ID: 20048}, + Account: &Account{ID: 30048}, + }) + + require.Error(t, err) + require.Equal(t, 1, billingRepo.calls) + require.Equal(t, 0, usageRepo.calls) +} + func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 44cfc83a..241c5cd6 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct { type OpenAIGatewayService struct { accountRepo AccountRepository usageLogRepo UsageLogRepository + usageBillingRepo UsageBillingRepository userRepo UserRepository userSubRepo UserSubscriptionRepository cache GatewayCache @@ -295,6 +296,7 @@ type OpenAIGatewayService struct { func NewOpenAIGatewayService( accountRepo AccountRepository, usageLogRepo UsageLogRepository, + usageBillingRepo UsageBillingRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, @@ -312,6 +314,7 @@ func NewOpenAIGatewayService( svc := &OpenAIGatewayService{ accountRepo: accountRepo, usageLogRepo: usageLogRepo, + usageBillingRepo: usageBillingRepo, userRepo: userRepo, userSubRepo: userSubRepo, cache: cache, @@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + 
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( return nil, err } - upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() if err != nil { return nil, err } @@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( var firstTokenMs *int clientDisconnected := false sawDone := false + sawTerminalEvent := false upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) scanner := bufio.NewScanner(resp.Body) @@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if trimmedData == "[DONE]" { sawDone = true } + if openAIStreamEventIsTerminal(trimmedData) { + sawTerminalEvent = true + } if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms @@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } } if err := scanner.Err(); err != nil { - if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + if sawTerminalEvent { return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil } + if clientDisconnected { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + } if errors.Is(err, context.Canceled) || errors.Is(err, 
context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", - "[OpenAI passthrough] stream read canceled, possible stream interruption: account=%d request_id=%s err=%v ctx_err=%v", - account.ID, - upstreamRequestID, - err, - ctx.Err(), - ) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) } if errors.Is(err, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) @@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) } - if !clientDisconnected && !sawDone && ctx.Err() == nil { + if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( zap.String("component", "service.openai_gateway"), zap.Int64("account_id", account.ID), zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough upstream stream ended before [DONE] was received; suspected stream interruption") + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") } return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil @@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // Otherwise downstream SDKs (for example OpenCode) will report an error because type validation fails. errorEventSent := false clientDisconnected := false // After the client disconnects, keep draining upstream to collect usage. + sawTerminalEvent := false sendErrorEvent := func(reason string) { if errorEventSent || clientDisconnected { return @@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning
collected usage") } } + if !sawTerminalEvent { + return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event") + } return resultWithUsage(), nil } handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { if scanErr == nil { return nil, nil, false } + if sawTerminalEvent { + logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr) + return resultWithUsage(), nil, true + } // When the client disconnects or cancels the request, the upstream read usually returns context canceled. // SSE events on /v1/responses must follow the OpenAI protocol; no custom error event is injected here, so downstream SDKs do not fail to parse. if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true } // When the client has already disconnected, an upstream error only affects the experience, not billing; return the collected usage. if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) - return resultWithUsage(), nil, true + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true } if errors.Is(scanErr, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) @@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } dataBytes := []byte(data) + if openAIStreamEventIsTerminal(data) { + sawTerminalEvent = true + } // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { @@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") - return resultWithUsage(), nil + return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout") } logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // Handle the stream timeout; may mark the account as temporarily unschedulable or in an error state. @@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { return } - // Selective parsing: only extract fields when the data contains the completed event marker. - if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + // Selective parsing: only extract fields when the data contains a terminal event marker. + if len(data) < 72 { return } - if gjson.GetBytes(data, "type").String() != "response.completed" { + eventType := gjson.GetBytes(data, "type").String() + if eventType != "response.completed" && eventType != "response.done" { return } @@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // User-Agent of the request - IPAddress string // Client IP address of the request - APIKeyService APIKeyQuotaUpdater + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // User-Agent of the request + IPAddress string // Client IP address of the request + RequestPayloadHash string + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance @@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService)
RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveUsageBillingRequestID(ctx, result.RequestID) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: billingModel, ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, @@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } - inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - if shouldBill { - postUsageBilling(ctx, &postUsageBillingParams{ + billingErr := func() error { + _, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ Cost: cost, User: user, APIKey: apiKey, Account: account, Subscription: subscription, + RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), IsSubscriptionBill: isSubscriptionBilling, AccountRateMultiplier: accountRateMultiplier, APIKeyService: input.APIKeyService, - }, s.billingDeps()) - } else { - s.deferredService.ScheduleLastUsedUpdate(account.ID) + }, s.billingDeps(), s.usageBillingRepo) + return err + }() + + if billingErr != nil { + return billingErr } + writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") return nil } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 
43e2f39d..9e2f33f2 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) { } } -func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { +func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ @@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { } _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") - if err != nil { - t.Fatalf("expected nil error, got %v", err) + if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") { + t.Fatalf("expected incomplete stream error, got %v", err) } if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) @@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { } } +func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, 
&Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + }() + + _, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + if err == nil || !strings.Contains(err.Error(), "missing terminal event") { + t.Fatalf("expected missing terminal event error, got %v", err) + } +} + +func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), 
resp, c, &Account{ID: 1}, time.Now()) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 2, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 1, result.usage.CacheReadInputTokens) +} + func TestOpenAIStreamingTooLong(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") @@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { require.Equal(t, 3, usage.InputTokens) require.Equal(t, 5, usage.OutputTokens) require.Equal(t, 2, usage.CacheReadInputTokens) + + // A done event may also carry the final usage. + svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage) + require.Equal(t, 13, usage.InputTokens) + require.Equal(t, 15, usage.OutputTokens) + require.Equal(t, 4, usage.CacheReadInputTokens) } func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index 7295b13d..08eb397b 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, cfg, nil, nil, diff --git a/backend/internal/service/usage_billing.go b/backend/internal/service/usage_billing.go new file mode 100644 index 00000000..73b05743 --- /dev/null +++
b/backend/internal/service/usage_billing.go @@ -0,0 +1,110 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strings" +) + +var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required") +var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict") + +// UsageBillingCommand describes one billable request that must be applied at most once. +type UsageBillingCommand struct { + RequestID string + APIKeyID int64 + RequestFingerprint string + RequestPayloadHash string + + UserID int64 + AccountID int64 + SubscriptionID *int64 + AccountType string + Model string + ServiceTier string + ReasoningEffort string + BillingType int8 + InputTokens int + OutputTokens int + CacheCreationTokens int + CacheReadTokens int + ImageCount int + MediaType string + + BalanceCost float64 + SubscriptionCost float64 + APIKeyQuotaCost float64 + APIKeyRateLimitCost float64 + AccountQuotaCost float64 +} + +func (c *UsageBillingCommand) Normalize() { + if c == nil { + return + } + c.RequestID = strings.TrimSpace(c.RequestID) + if strings.TrimSpace(c.RequestFingerprint) == "" { + c.RequestFingerprint = buildUsageBillingFingerprint(c) + } +} + +func buildUsageBillingFingerprint(c *UsageBillingCommand) string { + if c == nil { + return "" + } + raw := fmt.Sprintf( + "%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f", + c.UserID, + c.AccountID, + c.APIKeyID, + strings.TrimSpace(c.AccountType), + strings.TrimSpace(c.Model), + strings.TrimSpace(c.ServiceTier), + strings.TrimSpace(c.ReasoningEffort), + c.BillingType, + c.InputTokens, + c.OutputTokens, + c.CacheCreationTokens, + c.CacheReadTokens, + c.ImageCount, + strings.TrimSpace(c.MediaType), + valueOrZero(c.SubscriptionID), + c.BalanceCost, + c.SubscriptionCost, + c.APIKeyQuotaCost, + c.APIKeyRateLimitCost, + c.AccountQuotaCost, + ) + if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" 
{ + raw += "|" + payloadHash + } + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} + +func HashUsageRequestPayload(payload []byte) string { + if len(payload) == 0 { + return "" + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]) +} + +func valueOrZero(v *int64) int64 { + if v == nil { + return 0 + } + return *v +} + +type UsageBillingApplyResult struct { + Applied bool + APIKeyQuotaExhausted bool +} + +type UsageBillingRepository interface { + Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) +} diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 0fdbfd47..17f21bef 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -56,7 +56,8 @@ type cleanupRepoStub struct { } type dashboardRepoStub struct { - recomputeErr error + recomputeErr error + recomputeCalls int } func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { @@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time. 
} func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + s.recomputeCalls++ return s.recomputeErr } @@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti return nil } +func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error { + return nil +} + func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { return nil } @@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) { } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { + dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ - DashboardAgg: config.DashboardAggregationConfig{Enabled: false}, + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ + DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, }) cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} svc := NewUsageCleanupService(repo, nil, dashboard, cfg) @@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.markSucceeded, 1) + require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond) } func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { + dashboardRepo := &dashboardRepoStub{} repo := &cleanupRepoStub{ deleteQueue: []cleanupDeleteResponse{ {deleted: 0}, }, } - dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ + dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{ DashboardAgg: 
config.DashboardAggregationConfig{Enabled: true},
 	})
 	cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
 	repo.mu.Lock()
 	defer repo.mu.Unlock()
 	require.Len(t, repo.markSucceeded, 1)
+	require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
 }
 
 func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
diff --git a/backend/internal/service/usage_log_create_result.go b/backend/internal/service/usage_log_create_result.go
new file mode 100644
index 00000000..5e18b44c
--- /dev/null
+++ b/backend/internal/service/usage_log_create_result.go
@@ -0,0 +1,60 @@
+package service
+
+import "errors"
+
+type usageLogCreateDisposition int
+
+const (
+	usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
+	usageLogCreateDispositionNotPersisted
+)
+
+type UsageLogCreateError struct {
+	err         error
+	disposition usageLogCreateDisposition
+}
+
+func (e *UsageLogCreateError) Error() string {
+	if e == nil || e.err == nil {
+		return "usage log create error"
+	}
+	return e.err.Error()
+}
+
+func (e *UsageLogCreateError) Unwrap() error {
+	if e == nil {
+		return nil
+	}
+	return e.err
+}
+
+func MarkUsageLogCreateNotPersisted(err error) error {
+	if err == nil {
+		return nil
+	}
+	return &UsageLogCreateError{
+		err:         err,
+		disposition: usageLogCreateDispositionNotPersisted,
+	}
+}
+
+func IsUsageLogCreateNotPersisted(err error) bool {
+	if err == nil {
+		return false
+	}
+	var target *UsageLogCreateError
+	if !errors.As(err, &target) {
+		return false
+	}
+	return target.disposition == usageLogCreateDispositionNotPersisted
+}
+
+func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
+	if inserted {
+		return true
+	}
+	if err == nil {
+		return false
+	}
+	return !IsUsageLogCreateNotPersisted(err)
+}
diff --git a/backend/migrations/071_add_usage_billing_dedup.sql
b/backend/migrations/071_add_usage_billing_dedup.sql
new file mode 100644
index 00000000..acc28459
--- /dev/null
+++ b/backend/migrations/071_add_usage_billing_dedup.sql
@@ -0,0 +1,13 @@
+-- Narrow billing idempotency-key table: decouples the "already billed?" check from usage_logs
+-- Idempotent migration: safe to run repeatedly
+
+CREATE TABLE IF NOT EXISTS usage_billing_dedup (
+    id BIGSERIAL PRIMARY KEY,
+    request_id VARCHAR(255) NOT NULL,
+    api_key_id BIGINT NOT NULL,
+    request_fingerprint VARCHAR(64) NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
+    ON usage_billing_dedup (request_id, api_key_id);
diff --git a/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql
new file mode 100644
index 00000000..965a3412
--- /dev/null
+++ b/backend/migrations/072_add_usage_billing_dedup_created_at_brin_notx.sql
@@ -0,0 +1,7 @@
+-- usage_billing_dedup is a narrow, append-only idempotency table written in time order.
+-- A BRIN index supports bulk retention cleanup by created_at while keeping write amplification low.
+-- CONCURRENTLY avoids blocking writes on the hot table for long periods.
+
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
+    ON usage_billing_dedup
+    USING BRIN (created_at);
diff --git a/backend/migrations/073_add_usage_billing_dedup_archive.sql b/backend/migrations/073_add_usage_billing_dedup_archive.sql
new file mode 100644
index 00000000..d156d4eb
--- /dev/null
+++ b/backend/migrations/073_add_usage_billing_dedup_archive.sql
@@ -0,0 +1,10 @@
+-- Cold-archive old billing idempotency keys to shrink the hot table's indexes and cleanup range without losing long-term deduplication.
+
+CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
+    request_id VARCHAR(255) NOT NULL,
+    api_key_id BIGINT NOT NULL,
+    request_fingerprint VARCHAR(64) NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL,
+    archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+    PRIMARY KEY (request_id, api_key_id)
+);
diff --git a/deploy/build_image.sh b/deploy/build_image.sh
old mode 100755
new mode 100644
diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh
old mode 100755
new mode 100644