mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 02:09:43 +00:00
feat: decouple billing correctness from usage log batching
This commit is contained in:
@@ -934,9 +934,10 @@ type DashboardAggregationConfig struct {
|
||||
|
||||
// DashboardAggregationRetentionConfig 预聚合保留窗口
|
||||
type DashboardAggregationRetentionConfig struct {
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
UsageLogsDays int `mapstructure:"usage_logs_days"`
|
||||
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
|
||||
HourlyDays int `mapstructure:"hourly_days"`
|
||||
DailyDays int `mapstructure:"daily_days"`
|
||||
}
|
||||
|
||||
// UsageCleanupConfig 使用记录清理任务配置
|
||||
@@ -1301,6 +1302,7 @@ func setDefaults() {
|
||||
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
|
||||
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
|
||||
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
|
||||
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
|
||||
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
|
||||
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
|
||||
@@ -1758,6 +1760,12 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays <= 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive")
|
||||
}
|
||||
@@ -1780,6 +1788,14 @@ func (c *Config) Validate() error {
|
||||
if c.DashboardAgg.Retention.UsageLogsDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative")
|
||||
}
|
||||
if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageLogsDays > 0 &&
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days")
|
||||
}
|
||||
if c.DashboardAgg.Retention.HourlyDays < 0 {
|
||||
return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative")
|
||||
}
|
||||
|
||||
@@ -441,6 +441,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) {
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays != 90 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageBillingDedupDays != 365 {
|
||||
t.Fatalf("DashboardAgg.Retention.UsageBillingDedupDays = %d, want 365", cfg.DashboardAgg.Retention.UsageBillingDedupDays)
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.HourlyDays != 180 {
|
||||
t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays)
|
||||
}
|
||||
@@ -1016,6 +1019,23 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
|
||||
wantErr: "dashboard_aggregation.retention.usage_logs_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 0
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation dedup retention smaller than usage logs",
|
||||
mutate: func(c *Config) {
|
||||
c.DashboardAgg.Enabled = true
|
||||
c.DashboardAgg.Retention.UsageLogsDays = 30
|
||||
c.DashboardAgg.Retention.UsageBillingDedupDays = 29
|
||||
},
|
||||
wantErr: "dashboard_aggregation.retention.usage_billing_dedup_days",
|
||||
},
|
||||
{
|
||||
name: "dashboard aggregation disabled interval",
|
||||
mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
|
||||
|
||||
@@ -434,19 +434,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
@@ -736,19 +738,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: account,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gateway.messages"),
|
||||
|
||||
@@ -139,6 +139,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // usageBillingRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
|
||||
@@ -503,6 +503,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
@@ -512,6 +513,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
|
||||
@@ -352,18 +352,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
@@ -732,17 +734,19 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.messages"),
|
||||
@@ -1231,14 +1235,15 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
|
||||
@@ -2206,7 +2206,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
|
||||
// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
|
||||
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
|
||||
return service.NewGatewayService(
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil,
|
||||
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -399,17 +399,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
requestPayloadHash := service.HashUsageRequestPayload(body)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
RequestPayloadHash: requestPayloadHash,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
|
||||
@@ -431,6 +431,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
|
||||
@@ -17,6 +17,9 @@ type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
const usageLogsCleanupBatchSize = 10000
|
||||
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
@@ -42,6 +45,9 @@ func isPostgresDriver(db *sql.DB) bool {
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
@@ -61,6 +67,22 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
@@ -195,8 +217,58 @@ func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, c
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
|
||||
@@ -45,6 +45,20 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
@@ -75,6 +89,23 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
308
backend/internal/repository/usage_billing_repo.go
Normal file
308
backend/internal/repository/usage_billing_repo.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type usageBillingRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
||||
return &usageBillingRepository{db: sqlDB}
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
||||
if cmd == nil {
|
||||
return &service.UsageBillingApplyResult{}, nil
|
||||
}
|
||||
if r == nil || r.db == nil {
|
||||
return nil, errors.New("usage billing repository db is nil")
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
if cmd.RequestID == "" {
|
||||
return nil, service.ErrUsageBillingRequestIDRequired
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !applied {
|
||||
return &service.UsageBillingApplyResult{Applied: false}, nil
|
||||
}
|
||||
|
||||
result := &service.UsageBillingApplyResult{Applied: true}
|
||||
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx = nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
||||
var id int64
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id
|
||||
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
var existingFingerprint string
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var archivedFingerprint string
|
||||
err = tx.QueryRowContext(ctx, `
|
||||
SELECT request_fingerprint
|
||||
FROM usage_billing_dedup_archive
|
||||
WHERE request_id = $1 AND api_key_id = $2
|
||||
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
||||
if err == nil {
|
||||
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
||||
return false, service.ErrUsageBillingRequestConflict
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
||||
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
||||
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.BalanceCost > 0 {
|
||||
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.APIKeyQuotaCost > 0 {
|
||||
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.APIKeyQuotaExhausted = exhausted
|
||||
}
|
||||
|
||||
if cmd.APIKeyRateLimitCost > 0 {
|
||||
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
weekly_usage_usd = us.weekly_usage_usd + $1,
|
||||
monthly_usage_usd = us.monthly_usage_usd + $1,
|
||||
updated_at = NOW()
|
||||
FROM groups g
|
||||
WHERE us.id = $2
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
`
|
||||
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET balance = balance - $1,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, amount, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected > 0 {
|
||||
return nil
|
||||
}
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||
var exhausted bool
|
||||
err := tx.QueryRowContext(ctx, `
|
||||
UPDATE api_keys
|
||||
SET quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0
|
||||
AND status = $3
|
||||
AND quota_used < quota
|
||||
AND quota_used + $1 >= quota
|
||||
THEN $4
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
||||
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, service.ErrAPIKeyNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exhausted, nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
`, cost, apiKeyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
||||
rows, err := tx.QueryContext(ctx,
|
||||
`UPDATE accounts SET extra = (
|
||||
COALESCE(extra, '{}'::jsonb)
|
||||
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_daily_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||
'quota_daily_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '24 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||
jsonb_build_object(
|
||||
'quota_weekly_used',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN $1
|
||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||
'quota_weekly_start',
|
||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
||||
+ '168 hours'::interval <= NOW()
|
||||
THEN `+nowUTC+`
|
||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||
)
|
||||
ELSE '{}'::jsonb END
|
||||
), updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL
|
||||
RETURNING
|
||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
||||
amount, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var newUsed, limit float64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesBalanceBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-" + uuid.NewString(),
|
||||
Name: "billing",
|
||||
Quota: 1,
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
BalanceCost: 1.25,
|
||||
APIKeyQuotaCost: 1.25,
|
||||
APIKeyRateLimitCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result1)
|
||||
require.True(t, result1.Applied)
|
||||
require.True(t, result1.APIKeyQuotaExhausted)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result2)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT quota_used FROM api_keys WHERE id = $1", apiKey.ID).Scan("aUsed))
|
||||
require.InDelta(t, 1.25, quotaUsed, 0.000001)
|
||||
|
||||
var usage5h float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT usage_5h FROM api_keys WHERE id = $1", apiKey.ID).Scan(&usage5h))
|
||||
require.InDelta(t, 1.25, usage5h, 0.000001)
|
||||
|
||||
var status string
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT status FROM api_keys WHERE id = $1", apiKey.ID).Scan(&status))
|
||||
require.Equal(t, service.StatusAPIKeyQuotaExhausted, status)
|
||||
|
||||
var dedupCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&dedupCount))
|
||||
require.Equal(t, 1, dedupCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesSubscriptionBilling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-sub-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
group := mustCreateGroup(t, client, &service.Group{
|
||||
Name: "usage-billing-group-" + uuid.NewString(),
|
||||
Platform: service.PlatformAnthropic,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
GroupID: &group.ID,
|
||||
Key: "sk-usage-billing-sub-" + uuid.NewString(),
|
||||
Name: "billing-sub",
|
||||
})
|
||||
subscription := mustCreateSubscription(t, client, &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: 0,
|
||||
SubscriptionID: &subscription.ID,
|
||||
SubscriptionCost: 2.5,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var dailyUsage float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT daily_usage_usd FROM user_subscriptions WHERE id = $1", subscription.ID).Scan(&dailyUsage))
|
||||
require.InDelta(t, 2.5, dailyUsage, 0.000001)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_RequestFingerprintConflict(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-conflict-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-conflict-" + uuid.NewString(),
|
||||
Name: "billing-conflict",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 2.50,
|
||||
})
|
||||
require.ErrorIs(t, err, service.ErrUsageBillingRequestConflict)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-account-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-account-" + uuid.NewString(),
|
||||
Name: "billing-account",
|
||||
})
|
||||
account := mustCreateAccount(t, client, &service.Account{
|
||||
Name: "usage-billing-account-quota-" + uuid.NewString(),
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"quota_limit": 100.0,
|
||||
},
|
||||
})
|
||||
|
||||
_, err := repo.Apply(ctx, &service.UsageBillingCommand{
|
||||
RequestID: uuid.NewString(),
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
AccountID: account.ID,
|
||||
AccountType: service.AccountTypeAPIKey,
|
||||
AccountQuotaCost: 3.5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var quotaUsed float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COALESCE((extra->>'quota_used')::numeric, 0) FROM accounts WHERE id = $1", account.ID).Scan("aUsed))
|
||||
require.InDelta(t, 3.5, quotaUsed, 0.000001)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
oldRequestID := "dedup-old-" + uuid.NewString()
|
||||
newRequestID := "dedup-new-" + uuid.NewString()
|
||||
oldCreatedAt := time.Now().UTC().AddDate(0, 0, -400)
|
||||
newCreatedAt := time.Now().UTC().Add(-time.Hour)
|
||||
|
||||
_, err := integrationDB.ExecContext(ctx, `
|
||||
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint, created_at)
|
||||
VALUES ($1, 1, $2, $3), ($4, 1, $5, $6)
|
||||
`,
|
||||
oldRequestID, strings.Repeat("a", 64), oldCreatedAt,
|
||||
newRequestID, strings.Repeat("b", 64), newCreatedAt,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, repo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
var oldCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", oldRequestID).Scan(&oldCount))
|
||||
require.Equal(t, 0, oldCount)
|
||||
|
||||
var newCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup WHERE request_id = $1", newRequestID).Scan(&newCount))
|
||||
require.Equal(t, 1, newCount)
|
||||
|
||||
var archivedCount int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_billing_dedup_archive WHERE request_id = $1", oldRequestID).Scan(&archivedCount))
|
||||
require.Equal(t, 1, archivedCount)
|
||||
}
|
||||
|
||||
func TestUsageBillingRepositoryApply_DeduplicatesAgainstArchivedKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := NewUsageBillingRepository(client, integrationDB)
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("usage-billing-archive-user-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Balance: 100,
|
||||
})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-usage-billing-archive-" + uuid.NewString(),
|
||||
Name: "billing-archive",
|
||||
})
|
||||
|
||||
requestID := uuid.NewString()
|
||||
cmd := &service.UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: user.ID,
|
||||
BalanceCost: 1.25,
|
||||
}
|
||||
|
||||
result1, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result1.Applied)
|
||||
|
||||
_, err = integrationDB.ExecContext(ctx, `
|
||||
UPDATE usage_billing_dedup
|
||||
SET created_at = $1
|
||||
WHERE request_id = $2 AND api_key_id = $3
|
||||
`, time.Now().UTC().AddDate(0, 0, -400), requestID, apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, aggRepo.CleanupUsageBillingDedup(ctx, time.Now().UTC().AddDate(0, 0, -365)))
|
||||
|
||||
result2, err := repo.Apply(ctx, cmd)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result2.Applied)
|
||||
|
||||
var balance float64
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT balance FROM users WHERE id = $1", user.ID).Scan(&balance))
|
||||
require.InDelta(t, 98.75, balance, 0.000001)
|
||||
}
|
||||
@@ -3,12 +3,14 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -17,11 +19,13 @@ import (
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
@@ -47,18 +51,29 @@ type usageLogRepository struct {
|
||||
sql sqlExecutor
|
||||
db *sql.DB
|
||||
|
||||
createBatchOnce sync.Once
|
||||
createBatchCh chan usageLogCreateRequest
|
||||
createBatchOnce sync.Once
|
||||
createBatchCh chan usageLogCreateRequest
|
||||
bestEffortBatchOnce sync.Once
|
||||
bestEffortBatchCh chan usageLogBestEffortRequest
|
||||
bestEffortRecent *gocache.Cache
|
||||
}
|
||||
|
||||
const (
|
||||
usageLogCreateBatchMaxSize = 64
|
||||
usageLogCreateBatchWindow = 3 * time.Millisecond
|
||||
usageLogCreateBatchQueueCap = 4096
|
||||
usageLogCreateCancelWait = 2 * time.Second
|
||||
|
||||
usageLogBestEffortBatchMaxSize = 256
|
||||
usageLogBestEffortBatchWindow = 20 * time.Millisecond
|
||||
usageLogBestEffortBatchQueueCap = 32768
|
||||
usageLogBestEffortRecentTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
type usageLogCreateRequest struct {
|
||||
log *service.UsageLog
|
||||
prepared usageLogInsertPrepared
|
||||
shared *usageLogCreateShared
|
||||
resultCh chan usageLogCreateResult
|
||||
}
|
||||
|
||||
@@ -67,6 +82,12 @@ type usageLogCreateResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
type usageLogBestEffortRequest struct {
|
||||
prepared usageLogInsertPrepared
|
||||
apiKeyID int64
|
||||
resultCh chan error
|
||||
}
|
||||
|
||||
type usageLogInsertPrepared struct {
|
||||
createdAt time.Time
|
||||
requestID string
|
||||
@@ -80,6 +101,25 @@ type usageLogBatchState struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type usageLogBatchRow struct {
|
||||
RequestID string `json:"request_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Inserted bool `json:"inserted"`
|
||||
}
|
||||
|
||||
type usageLogCreateShared struct {
|
||||
state atomic.Int32
|
||||
}
|
||||
|
||||
const (
|
||||
usageLogCreateStateQueued int32 = iota
|
||||
usageLogCreateStateProcessing
|
||||
usageLogCreateStateCompleted
|
||||
usageLogCreateStateCanceled
|
||||
)
|
||||
|
||||
func NewUsageLogRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageLogRepository {
|
||||
return newUsageLogRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
@@ -90,6 +130,7 @@ func newUsageLogRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usage
|
||||
if db, ok := sqlq.(*sql.DB); ok {
|
||||
repo.db = db
|
||||
}
|
||||
repo.bestEffortRecent = gocache.New(usageLogBestEffortRecentTTL, time.Minute)
|
||||
return repo
|
||||
}
|
||||
|
||||
@@ -124,9 +165,6 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return r.createSingle(ctx, tx.Client(), log)
|
||||
}
|
||||
if r.db == nil {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
requestID := strings.TrimSpace(log.RequestID)
|
||||
if requestID == "" {
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
@@ -135,11 +173,61 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
return r.createBatched(ctx, log)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error {
|
||||
if log == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
_, err := r.createSingle(ctx, tx.Client(), log)
|
||||
return err
|
||||
}
|
||||
if r.db == nil {
|
||||
_, err := r.createSingle(ctx, r.sql, log)
|
||||
return err
|
||||
}
|
||||
|
||||
r.ensureBestEffortBatcher()
|
||||
if r.bestEffortBatchCh == nil {
|
||||
_, err := r.createSingle(ctx, r.sql, log)
|
||||
return err
|
||||
}
|
||||
|
||||
req := usageLogBestEffortRequest{
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
apiKeyID: log.APIKeyID,
|
||||
resultCh: make(chan error, 1),
|
||||
}
|
||||
if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok {
|
||||
if _, exists := r.bestEffortRecent.Get(key); exists {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case r.bestEffortBatchCh <- req:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return errors.New("usage log best-effort queue full")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-req.resultCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) {
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
if sqlq == nil {
|
||||
sqlq = r.sql
|
||||
}
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO usage_logs (
|
||||
@@ -218,13 +306,15 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
|
||||
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case r.createBatchCh <- req:
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
default:
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
}
|
||||
@@ -233,7 +323,17 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
|
||||
case res := <-req.resultCh:
|
||||
return res.inserted, res.err
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) {
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
}
|
||||
timer := time.NewTimer(usageLogCreateCancelWait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case res := <-req.resultCh:
|
||||
return res.inserted, res.err
|
||||
case <-timer.C:
|
||||
return false, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +347,16 @@ func (r *usageLogRepository) ensureCreateBatcher() {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ensureBestEffortBatcher() {
|
||||
if r == nil || r.db == nil {
|
||||
return
|
||||
}
|
||||
r.bestEffortBatchOnce.Do(func() {
|
||||
r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap)
|
||||
go r.runBestEffortBatcher(r.db)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
|
||||
for {
|
||||
first, ok := <-r.createBatchCh
|
||||
@@ -281,6 +391,40 @@ func (r *usageLogRepository) runCreateBatcher(db *sql.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) {
|
||||
for {
|
||||
first, ok := <-r.bestEffortBatchCh
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize)
|
||||
batch = append(batch, first)
|
||||
|
||||
timer := time.NewTimer(usageLogBestEffortBatchWindow)
|
||||
bestEffortLoop:
|
||||
for len(batch) < usageLogBestEffortBatchMaxSize {
|
||||
select {
|
||||
case req, ok := <-r.bestEffortBatchCh:
|
||||
if !ok {
|
||||
break bestEffortLoop
|
||||
}
|
||||
batch = append(batch, req)
|
||||
case <-timer.C:
|
||||
break bestEffortLoop
|
||||
}
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
r.flushBestEffortBatch(db, batch)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
@@ -293,10 +437,19 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
|
||||
for _, req := range batch {
|
||||
if req.log == nil {
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: false, err: nil})
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
continue
|
||||
}
|
||||
prepared := prepareUsageLogInsert(req.log)
|
||||
if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) {
|
||||
if req.shared.state.Load() == usageLogCreateStateCanceled {
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{
|
||||
inserted: false,
|
||||
err: service.MarkUsageLogCreateNotPersisted(context.Canceled),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
prepared := req.prepared
|
||||
if prepared.requestID == "" {
|
||||
fallback = append(fallback, req)
|
||||
continue
|
||||
@@ -310,10 +463,37 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
}
|
||||
|
||||
if len(uniqueOrder) > 0 {
|
||||
insertedMap, stateMap, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
|
||||
insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey)
|
||||
if err != nil {
|
||||
for _, key := range uniqueOrder {
|
||||
fallback = append(fallback, requestsByKey[key]...)
|
||||
if safeFallback {
|
||||
for _, key := range uniqueOrder {
|
||||
fallback = append(fallback, requestsByKey[key]...)
|
||||
}
|
||||
} else {
|
||||
for _, key := range uniqueOrder {
|
||||
reqs := requestsByKey[key]
|
||||
state, hasState := stateMap[key]
|
||||
inserted := insertedMap[key]
|
||||
for idx, req := range reqs {
|
||||
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
|
||||
if hasState {
|
||||
req.log.ID = state.ID
|
||||
req.log.CreatedAt = state.CreatedAt
|
||||
}
|
||||
switch {
|
||||
case inserted && idx == 0:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil})
|
||||
case inserted:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
case hasState:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
case idx == 0:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err})
|
||||
default:
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, key := range uniqueOrder {
|
||||
@@ -321,7 +501,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
state, ok := stateMap[key]
|
||||
if !ok {
|
||||
for _, req := range reqs {
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{
|
||||
inserted: false,
|
||||
err: fmt.Errorf("usage log batch state missing for key=%s", key),
|
||||
})
|
||||
@@ -332,7 +512,7 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
req.log.ID = state.ID
|
||||
req.log.CreatedAt = state.CreatedAt
|
||||
req.log.RateMultiplier = preparedByKey[key].rateMultiplier
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{
|
||||
inserted: idx == 0 && insertedMap[key],
|
||||
err: nil,
|
||||
})
|
||||
@@ -345,56 +525,366 @@ func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreate
|
||||
return
|
||||
}
|
||||
|
||||
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
for _, req := range fallback {
|
||||
fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
inserted, err := r.createSingle(fallbackCtx, db, req.log)
|
||||
sendUsageLogCreateResult(req.resultCh, usageLogCreateResult{inserted: inserted, err: err})
|
||||
cancel()
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, error) {
|
||||
func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
type bestEffortGroup struct {
|
||||
prepared usageLogInsertPrepared
|
||||
apiKeyID int64
|
||||
key string
|
||||
reqs []usageLogBestEffortRequest
|
||||
}
|
||||
|
||||
groupsByKey := make(map[string]*bestEffortGroup, len(batch))
|
||||
groupOrder := make([]*bestEffortGroup, 0, len(batch))
|
||||
preparedList := make([]usageLogInsertPrepared, 0, len(batch))
|
||||
|
||||
for idx, req := range batch {
|
||||
prepared := req.prepared
|
||||
key := fmt.Sprintf("__best_effort_%d", idx)
|
||||
if prepared.requestID != "" {
|
||||
key = usageLogBatchKey(prepared.requestID, req.apiKeyID)
|
||||
}
|
||||
group, exists := groupsByKey[key]
|
||||
if !exists {
|
||||
group = &bestEffortGroup{
|
||||
prepared: prepared,
|
||||
apiKeyID: req.apiKeyID,
|
||||
key: key,
|
||||
}
|
||||
groupsByKey[key] = group
|
||||
groupOrder = append(groupOrder, group)
|
||||
preparedList = append(preparedList, prepared)
|
||||
}
|
||||
group.reqs = append(group.reqs, req)
|
||||
}
|
||||
|
||||
if len(preparedList) == 0 {
|
||||
for _, req := range batch {
|
||||
sendUsageLogBestEffortResult(req.resultCh, nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
query, args := buildUsageLogBestEffortInsertQuery(preparedList)
|
||||
if _, err := db.ExecContext(ctx, query, args...); err != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err)
|
||||
for _, group := range groupOrder {
|
||||
singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared)
|
||||
if singleErr != nil {
|
||||
logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr)
|
||||
} else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
|
||||
r.bestEffortRecent.SetDefault(group.key, struct{}{})
|
||||
}
|
||||
for _, req := range group.reqs {
|
||||
sendUsageLogBestEffortResult(req.resultCh, singleErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, group := range groupOrder {
|
||||
if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil {
|
||||
r.bestEffortRecent.SetDefault(group.key, struct{}{})
|
||||
}
|
||||
for _, req := range group.reqs {
|
||||
sendUsageLogBestEffortResult(req.resultCh, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendUsageLogBestEffortResult(ch chan error, err error) {
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreateResult) {
|
||||
if req.shared != nil {
|
||||
req.shared.state.Store(usageLogCreateStateCompleted)
|
||||
}
|
||||
sendUsageLogCreateResult(req.resultCh, res)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) {
|
||||
if len(keys) == 0 {
|
||||
return map[string]bool{}, map[string]usageLogBatchState{}, nil
|
||||
return map[string]bool{}, map[string]usageLogBatchState{}, false, nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey)
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
var payload []byte
|
||||
if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil {
|
||||
return nil, nil, true, err
|
||||
}
|
||||
var rows []usageLogBatchRow
|
||||
if err := json.Unmarshal(payload, &rows); err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
insertedMap := make(map[string]bool, len(keys))
|
||||
for rows.Next() {
|
||||
var (
|
||||
requestID string
|
||||
apiKeyID int64
|
||||
id int64
|
||||
createdAt time.Time
|
||||
)
|
||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, nil, err
|
||||
stateMap := make(map[string]usageLogBatchState, len(keys))
|
||||
for _, row := range rows {
|
||||
key := usageLogBatchKey(row.RequestID, row.APIKeyID)
|
||||
insertedMap[key] = row.Inserted
|
||||
stateMap[key] = usageLogBatchState{
|
||||
ID: row.ID,
|
||||
CreatedAt: row.CreatedAt,
|
||||
}
|
||||
insertedMap[usageLogBatchKey(requestID, apiKeyID)] = true
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, nil, err
|
||||
if len(stateMap) != len(keys) {
|
||||
return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys))
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
stateMap, err := loadUsageLogBatchStates(ctx, db, keys, preparedByKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return insertedMap, stateMap, nil
|
||||
return insertedMap, stateMap, false, nil
|
||||
}
|
||||
|
||||
func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`
|
||||
WITH input (
|
||||
input_idx,
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*37)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
args = append(args, idx)
|
||||
argPos++
|
||||
prepared := preparedByKey[key]
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
_, _ = query.WriteString(",")
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
_, _ = query.WriteString(`
|
||||
),
|
||||
inserted AS (
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
)
|
||||
SELECT
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO UPDATE
|
||||
SET request_id = usage_logs.request_id
|
||||
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
|
||||
)
|
||||
SELECT COALESCE(
|
||||
json_agg(
|
||||
json_build_object(
|
||||
'request_id', inserted.request_id,
|
||||
'api_key_id', inserted.api_key_id,
|
||||
'id', inserted.id,
|
||||
'created_at', inserted.created_at,
|
||||
'inserted', inserted.inserted
|
||||
)
|
||||
ORDER BY input.input_idx
|
||||
),
|
||||
'[]'::json
|
||||
)
|
||||
FROM input
|
||||
JOIN inserted
|
||||
ON inserted.request_id = input.request_id
|
||||
AND inserted.api_key_id = input.api_key_id
|
||||
`)
|
||||
return query.String(), args
|
||||
}
|
||||
|
||||
func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`
|
||||
WITH input (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) AS (VALUES `)
|
||||
|
||||
args := make([]any, 0, len(preparedList)*36)
|
||||
argPos := 1
|
||||
for idx, prepared := range preparedList {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
if i > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
|
||||
_, _ = query.WriteString(`
|
||||
)
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
@@ -432,80 +922,101 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES `)
|
||||
|
||||
args := make([]any, 0, len(keys)*36)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("(")
|
||||
prepared := preparedByKey[key]
|
||||
for i := 0; i < len(prepared.args); i++ {
|
||||
if i > 0 {
|
||||
_, _ = query.WriteString(",")
|
||||
}
|
||||
_, _ = query.WriteString("$")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
argPos++
|
||||
}
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.args...)
|
||||
}
|
||||
_, _ = query.WriteString(`
|
||||
)
|
||||
SELECT
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING request_id, api_key_id, id, created_at
|
||||
`)
|
||||
|
||||
return query.String(), args
|
||||
}
|
||||
|
||||
func loadUsageLogBatchStates(ctx context.Context, db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]usageLogBatchState, error) {
|
||||
var query strings.Builder
|
||||
_, _ = query.WriteString(`SELECT request_id, api_key_id, id, created_at FROM usage_logs WHERE `)
|
||||
args := make([]any, 0, len(keys)*2)
|
||||
argPos := 1
|
||||
for idx, key := range keys {
|
||||
if idx > 0 {
|
||||
_, _ = query.WriteString(" OR ")
|
||||
}
|
||||
prepared := preparedByKey[key]
|
||||
apiKeyID := prepared.args[1]
|
||||
_, _ = query.WriteString("(request_id = $")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos))
|
||||
_, _ = query.WriteString(" AND api_key_id = $")
|
||||
_, _ = query.WriteString(strconv.Itoa(argPos + 1))
|
||||
_, _ = query.WriteString(")")
|
||||
args = append(args, prepared.requestID, apiKeyID)
|
||||
argPos += 2
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx, query.String(), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
stateMap := make(map[string]usageLogBatchState, len(keys))
|
||||
for rows.Next() {
|
||||
var (
|
||||
requestID string
|
||||
apiKeyID int64
|
||||
id int64
|
||||
createdAt time.Time
|
||||
func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error {
|
||||
_, err := sqlq.ExecContext(ctx, `
|
||||
INSERT INTO usage_logs (
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
request_id,
|
||||
model,
|
||||
group_id,
|
||||
subscription_id,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
cache_creation_5m_tokens,
|
||||
cache_creation_1h_tokens,
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
request_type,
|
||||
stream,
|
||||
openai_ws_mode,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
service_tier,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36
|
||||
)
|
||||
if err := rows.Scan(&requestID, &apiKeyID, &id, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateMap[usageLogBatchKey(requestID, apiKeyID)] = usageLogBatchState{
|
||||
ID: id,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stateMap, nil
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
`, prepared.args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
||||
@@ -597,6 +1108,14 @@ func sendUsageLogCreateResult(ch chan usageLogCreateResult, res usageLogCreateRe
|
||||
}
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int64) (string, bool) {
|
||||
requestID = strings.TrimSpace(requestID)
|
||||
if requestID == "" || r == nil || r.bestEffortRecent == nil {
|
||||
return "", false
|
||||
}
|
||||
return usageLogBatchKey(requestID, apiKeyID), true
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
|
||||
rows, err := r.sql.QueryContext(ctx, query, id)
|
||||
|
||||
@@ -183,6 +183,214 @@ func TestUsageLogRepositoryCreate_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_DeduplicatesSameKeyInMemory(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-batch-memdup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-batch-memdup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-batch-memdup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
const total = 8
|
||||
batch := make([]usageLogCreateRequest, 0, total)
|
||||
logs := make([]*service.UsageLog, 0, total)
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10 + i,
|
||||
OutputTokens: 20 + i,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
logs = append(logs, log)
|
||||
batch = append(batch, usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
})
|
||||
}
|
||||
|
||||
repo.flushCreateBatch(integrationDB, batch)
|
||||
|
||||
insertedCount := 0
|
||||
var firstID int64
|
||||
for idx, req := range batch {
|
||||
res := <-req.resultCh
|
||||
require.NoError(t, res.err)
|
||||
if res.inserted {
|
||||
insertedCount++
|
||||
}
|
||||
require.NotZero(t, logs[idx].ID)
|
||||
if idx == 0 {
|
||||
firstID = logs[idx].ID
|
||||
} else {
|
||||
require.Equal(t, firstID, logs[idx].ID)
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, 1, insertedCount)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-dup-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-dup-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-dup-" + uuid.NewString()})
|
||||
requestID := uuid.NewString()
|
||||
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: requestID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log1))
|
||||
require.NoError(t, repo.CreateBestEffort(ctx, log2))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
var count int
|
||||
err := integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM usage_logs WHERE request_id = $1 AND api_key_id = $2", requestID, apiKey.ID).Scan(&count)
|
||||
return err == nil && count == 1
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-cancel-queued-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-cancel-queued-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-cancel-queued-" + uuid.NewString()})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
_, err := repo.createBatched(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
req := <-repo.createBatchCh
|
||||
require.NotNil(t, req.shared)
|
||||
cancel()
|
||||
|
||||
err := <-errCh
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: service.MarkUsageLogCreateNotPersisted(context.Canceled)})
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryFlushCreateBatch_CanceledRequestReturnsNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-flush-cancel-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-flush-cancel-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-flush-cancel-" + uuid.NewString()})
|
||||
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
req := usageLogCreateRequest{
|
||||
log: log,
|
||||
prepared: prepareUsageLogInsert(log),
|
||||
shared: &usageLogCreateShared{},
|
||||
resultCh: make(chan usageLogCreateResult, 1),
|
||||
}
|
||||
req.shared.state.Store(usageLogCreateStateCanceled)
|
||||
|
||||
repo.flushCreateBatch(integrationDB, []usageLogCreateRequest{req})
|
||||
|
||||
res := <-req.resultCh
|
||||
require.False(t, res.inserted)
|
||||
require.Error(t, res.err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(res.err))
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||
|
||||
@@ -62,6 +62,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageBillingRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
|
||||
@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
|
||||
if usageErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||
if dedupErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,12 +12,18 @@ import (
|
||||
|
||||
type dashboardAggregationRepoTestStub struct {
|
||||
aggregateCalls int
|
||||
recomputeCalls int
|
||||
cleanupUsageCalls int
|
||||
cleanupDedupCalls int
|
||||
ensurePartitionCalls int
|
||||
lastStart time.Time
|
||||
lastEnd time.Time
|
||||
watermark time.Time
|
||||
aggregateErr error
|
||||
cleanupAggregatesErr error
|
||||
cleanupUsageErr error
|
||||
cleanupDedupErr error
|
||||
ensurePartitionErr error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupUsageCalls++
|
||||
return s.cleanupUsageErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupDedupCalls++
|
||||
return s.cleanupDedupErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
s.ensurePartitionCalls++
|
||||
return s.ensurePartitionErr
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
UsageBillingDedupDays: 2,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
|
||||
@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
deferredService: &DeferredService{},
|
||||
billingCacheService: nil,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
|
||||
require.Equal(t, 5, result.usage.OutputTokens)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
},
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
|
||||
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
|
||||
"",
|
||||
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
|
||||
"",
|
||||
}, "\n"))),
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
|
||||
_ = pr.Close()
|
||||
<-done
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 9, result.usage.InputTokens)
|
||||
@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
}
|
||||
@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
|
||||
}
|
||||
|
||||
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
|
||||
require.NotNil(t, result)
|
||||
require.True(t, result.clientDisconnect)
|
||||
require.Equal(t, 8, result.usage.InputTokens)
|
||||
|
||||
261
backend/internal/service/gateway_record_usage_test.go
Normal file
261
backend/internal/service/gateway_record_usage_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
return NewGatewayService(
|
||||
nil,
|
||||
nil,
|
||||
usageRepo,
|
||||
nil,
|
||||
userRepo,
|
||||
subRepo,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
&BillingCacheService{},
|
||||
nil,
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
svc.usageBillingRepo = billingRepo
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 501,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_hash",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_payload_fallback",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_not_persisted",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 503,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 603},
|
||||
Account: &Account{ID: 703},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_long_context_detached_ctx",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 12,
|
||||
OutputTokens: 8,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 502,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 602},
|
||||
Account: &Account{ID: 702},
|
||||
LongContextThreshold: 200000,
|
||||
LongContextMultiplier: 2,
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 504},
|
||||
User: &User{ID: 604},
|
||||
Account: &Account{ID: 704},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_billing_fail",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 505},
|
||||
User: &User{ID: 605},
|
||||
Account: &Account{ID: 705},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
@@ -50,6 +50,7 @@ const (
|
||||
|
||||
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||
defaultModelsListCacheTTL = 15 * time.Second
|
||||
postUsageBillingTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
|
||||
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
|
||||
}
|
||||
|
||||
func claudeUsageHasAnyTokens(usage *ClaudeUsage) bool {
|
||||
return usage != nil && (usage.InputTokens > 0 ||
|
||||
usage.OutputTokens > 0 ||
|
||||
usage.CacheCreationInputTokens > 0 ||
|
||||
usage.CacheReadInputTokens > 0 ||
|
||||
usage.CacheCreation5mTokens > 0 ||
|
||||
usage.CacheCreation1hTokens > 0)
|
||||
}
|
||||
|
||||
func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool {
|
||||
return usage != nil && (usage.InputTokens > 0 ||
|
||||
usage.OutputTokens > 0 ||
|
||||
usage.CacheCreationInputTokens > 0 ||
|
||||
usage.CacheReadInputTokens > 0)
|
||||
}
|
||||
|
||||
func openAIStreamEventIsTerminal(data string) bool {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if trimmed == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
switch gjson.Get(trimmed, "type").String() {
|
||||
case "response.completed", "response.done", "response.failed":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func anthropicStreamEventIsTerminal(eventName, data string) bool {
|
||||
if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") {
|
||||
return true
|
||||
}
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if trimmed == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
return gjson.Get(trimmed, "type").String() == "message_stop"
|
||||
}
|
||||
|
||||
func cloneStringSlice(src []string) []string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -504,6 +551,7 @@ type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
@@ -537,6 +585,7 @@ func NewGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
groupRepo GroupRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -563,6 +612,7 @@ func NewGatewayService(
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
@@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// also downgrade tool_use/tool_result blocks to text.
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseRetryCtx()
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
@@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseRetryCtx2()
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr2 == nil {
|
||||
@@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
rectifiedBody, applied := RectifyThinkingBudget(body)
|
||||
if applied && time.Since(retryStart) < maxRetryElapsed {
|
||||
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
releaseBudgetRetryCtx()
|
||||
if buildErr == nil {
|
||||
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
@@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawTerminalEvent := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
@@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
|
||||
flusher.Flush()
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||
@@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
line := ev.line
|
||||
if data, ok := extractAnthropicSSEDataLine(line); ok {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if anthropicStreamEventIsTerminal("", trimmed) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsagePassthrough(data, usage)
|
||||
} else {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
@@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||
if s.rateLimitService != nil {
|
||||
@@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
sawTerminalEvent := false
|
||||
|
||||
pendingEventLines := make([]string, 0, 4)
|
||||
|
||||
@@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
if dataLine == "[DONE]" {
|
||||
sawTerminalEvent = true
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
block = "event: " + eventName + "\n"
|
||||
@@ -6107,6 +6180,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
usagePatch := s.extractSSEUsagePatch(event)
|
||||
if anthropicStreamEventIsTerminal(eventName, dataLine) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if !eventChanged {
|
||||
block := ""
|
||||
if eventName != "" {
|
||||
@@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
// 上游完成,返回结果
|
||||
if !sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if sawTerminalEvent {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
|
||||
}
|
||||
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
|
||||
}
|
||||
// 客户端未断开,正常的错误处理
|
||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||
@@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
// 客户端已断开,上游也超时了,返回已收集的 usage
|
||||
logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
|
||||
@@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface {
|
||||
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
type apiKeyAuthCacheInvalidator interface {
|
||||
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||
}
|
||||
|
||||
type usageLogBestEffortWriter interface {
|
||||
CreateBestEffort(ctx context.Context, log *UsageLog) error
|
||||
}
|
||||
|
||||
// postUsageBillingParams 统一扣费所需的参数
|
||||
type postUsageBillingParams struct {
|
||||
Cost *CostBreakdown
|
||||
@@ -6581,6 +6668,7 @@ type postUsageBillingParams struct {
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
RequestPayloadHash string
|
||||
IsSubscriptionBill bool
|
||||
AccountRateMultiplier float64
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
@@ -6592,19 +6680,22 @@ type postUsageBillingParams struct {
|
||||
// - API Key 限速用量更新
|
||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
cost := p.Cost
|
||||
|
||||
// 1. 订阅 / 余额扣费
|
||||
if p.IsSubscriptionBill {
|
||||
if cost.TotalCost > 0 {
|
||||
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil {
|
||||
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
||||
@@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
|
||||
// 2. API Key 配额
|
||||
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. API Key 限速用量
|
||||
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||
}
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
||||
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil {
|
||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 更新账号最近使用时间
|
||||
finalizePostUsageBilling(p, deps)
|
||||
}
|
||||
|
||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
}
|
||||
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
|
||||
if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" {
|
||||
return payloadHash
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
}
|
||||
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand {
|
||||
if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := &UsageBillingCommand{
|
||||
RequestID: requestID,
|
||||
APIKeyID: p.APIKey.ID,
|
||||
UserID: p.User.ID,
|
||||
AccountID: p.Account.ID,
|
||||
AccountType: p.Account.Type,
|
||||
RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash),
|
||||
}
|
||||
if usageLog != nil {
|
||||
cmd.Model = usageLog.Model
|
||||
cmd.BillingType = usageLog.BillingType
|
||||
cmd.InputTokens = usageLog.InputTokens
|
||||
cmd.OutputTokens = usageLog.OutputTokens
|
||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||
cmd.ImageCount = usageLog.ImageCount
|
||||
if usageLog.MediaType != nil {
|
||||
cmd.MediaType = *usageLog.MediaType
|
||||
}
|
||||
if usageLog.ServiceTier != nil {
|
||||
cmd.ServiceTier = *usageLog.ServiceTier
|
||||
}
|
||||
if usageLog.ReasoningEffort != nil {
|
||||
cmd.ReasoningEffort = *usageLog.ReasoningEffort
|
||||
}
|
||||
if usageLog.SubscriptionID != nil {
|
||||
cmd.SubscriptionID = usageLog.SubscriptionID
|
||||
}
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
|
||||
cmd.SubscriptionID = &p.Subscription.ID
|
||||
cmd.SubscriptionCost = p.Cost.TotalCost
|
||||
} else if p.Cost.ActualCost > 0 {
|
||||
cmd.BalanceCost = p.Cost.ActualCost
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
|
||||
cmd.APIKeyQuotaCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
|
||||
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
|
||||
}
|
||||
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
|
||||
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
|
||||
}
|
||||
|
||||
cmd.Normalize()
|
||||
return cmd
|
||||
}
|
||||
|
||||
func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) {
|
||||
if p == nil || deps == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
cmd := buildUsageBillingCommand(requestID, usageLog, p)
|
||||
if cmd == nil || cmd.RequestID == "" || repo == nil {
|
||||
postUsageBilling(ctx, p, deps)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
billingCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
result, err := repo.Apply(billingCtx, cmd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if result == nil || !result.Applied {
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if result.APIKeyQuotaExhausted {
|
||||
if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" {
|
||||
invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key)
|
||||
}
|
||||
}
|
||||
|
||||
finalizePostUsageBilling(p, deps)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
||||
if p == nil || p.Cost == nil || deps == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if p.IsSubscriptionBill {
|
||||
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
|
||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
|
||||
}
|
||||
} else if p.Cost.ActualCost > 0 && p.User != nil {
|
||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() {
|
||||
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost)
|
||||
}
|
||||
|
||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||
}
|
||||
|
||||
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
return context.WithTimeout(base, postUsageBillingTimeout)
|
||||
}
|
||||
|
||||
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
|
||||
if !stream {
|
||||
return ctx, func() {}
|
||||
}
|
||||
if ctx == nil {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
return context.WithoutCancel(ctx), func() {}
|
||||
}
|
||||
|
||||
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
|
||||
type billingDeps struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps {
|
||||
}
|
||||
}
|
||||
|
||||
func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) {
|
||||
if repo == nil || usageLog == nil {
|
||||
return
|
||||
}
|
||||
usageCtx, cancel := detachedBillingContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if writer, ok := repo.(usageLogBestEffortWriter); ok {
|
||||
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := repo.Create(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
@@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
mediaType = &result.MediaType
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct {
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService *APIKeyService // API Key 配额服务(可选)
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
|
||||
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
@@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
@@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "missing terminal event")
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,35 +7,63 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIRecordUsageLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
inserted bool
|
||||
err error
|
||||
calls int
|
||||
lastLog *UsageLog
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.calls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
type openAIRecordUsageBillingRepoStub struct {
|
||||
UsageBillingRepository
|
||||
|
||||
result *UsageBillingApplyResult
|
||||
err error
|
||||
calls int
|
||||
lastCmd *UsageBillingCommand
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
|
||||
s.calls++
|
||||
s.lastCmd = cmd
|
||||
s.lastCtxErr = ctx.Err()
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.result != nil {
|
||||
return s.result, nil
|
||||
}
|
||||
return &UsageBillingApplyResult{Applied: true}, nil
|
||||
}
|
||||
|
||||
type openAIRecordUsageUserRepoStub struct {
|
||||
UserRepository
|
||||
|
||||
deductCalls int
|
||||
deductErr error
|
||||
lastAmount float64
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
s.deductCalls++
|
||||
s.lastAmount = amount
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.deductErr
|
||||
}
|
||||
|
||||
@@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct {
|
||||
|
||||
incrementCalls int
|
||||
incrementErr error
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
s.incrementCalls++
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.incrementErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageAPIKeyQuotaStub struct {
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
quotaCalls int
|
||||
rateLimitCalls int
|
||||
err error
|
||||
lastAmount float64
|
||||
lastQuotaCtxErr error
|
||||
lastRateLimitCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.quotaCalls++
|
||||
s.lastAmount = cost
|
||||
s.lastQuotaCtxErr = ctx.Err()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
s.rateLimitCalls++
|
||||
s.lastAmount = cost
|
||||
s.lastRateLimitCtxErr = ctx.Err()
|
||||
return s.err
|
||||
}
|
||||
|
||||
@@ -93,23 +127,38 @@ func i64p(v int64) *int64 {
|
||||
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.1
|
||||
svc := NewOpenAIGatewayService(
|
||||
nil,
|
||||
usageRepo,
|
||||
nil,
|
||||
userRepo,
|
||||
subRepo,
|
||||
rateRepo,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
NewBillingService(cfg, nil),
|
||||
nil,
|
||||
&BillingCacheService{},
|
||||
nil,
|
||||
&DeferredService{},
|
||||
nil,
|
||||
)
|
||||
svc.userGroupRateResolver = newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
)
|
||||
return svc
|
||||
}
|
||||
|
||||
return &OpenAIGatewayService{
|
||||
usageLogRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: subRepo,
|
||||
cfg: cfg,
|
||||
billingService: NewBillingService(cfg, nil),
|
||||
billingCacheService: &BillingCacheService{},
|
||||
deferredService: &DeferredService{},
|
||||
userGroupRateResolver: newUserGroupRateResolver(
|
||||
rateRepo,
|
||||
nil,
|
||||
resolveUserGroupRateCacheTTL(cfg),
|
||||
nil,
|
||||
"service.openai_gateway.test",
|
||||
),
|
||||
}
|
||||
func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
|
||||
svc.usageBillingRepo = billingRepo
|
||||
return svc
|
||||
}
|
||||
|
||||
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
|
||||
@@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
@@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_duplicate_billing_key",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10045,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20045},
|
||||
Account: &Account{ID: 30045},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 0, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 0, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_usage_log_error",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10041},
|
||||
User: &User{ID: 20041},
|
||||
Account: &Account{ID: 30041},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_not_persisted",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10043,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20043},
|
||||
Account: &Account{ID: 30043},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.Equal(t, 0, subRepo.incrementCalls)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
|
||||
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_detached_billing_ctx",
|
||||
Usage: usage,
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{
|
||||
ID: 10042,
|
||||
Quota: 100,
|
||||
},
|
||||
User: &User{ID: 20042},
|
||||
Account: &Account{ID: 30042},
|
||||
APIKeyService: quotaSvc,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, userRepo.deductCalls)
|
||||
require.NoError(t, userRepo.lastCtxErr)
|
||||
require.Equal(t, 1, quotaSvc.quotaCalls)
|
||||
require.NoError(t, quotaSvc.lastQuotaCtxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_detached_billing_repo_ctx",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10046},
|
||||
User: &User{ID: 20046},
|
||||
Account: &Account{ID: 30046},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.NoError(t, billingRepo.lastCtxErr)
|
||||
require.Equal(t, 1, usageRepo.calls)
|
||||
require.NoError(t, usageRepo.lastCtxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
|
||||
|
||||
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "openai_payload_hash",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "gpt-5",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 501, Quota: 100},
|
||||
User: &User{ID: 601},
|
||||
Account: &Account{ID: 701},
|
||||
RequestPayloadHash: payloadHash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
|
||||
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10047},
|
||||
User: &User{ID: 20047},
|
||||
Account: &Account{ID: 30047},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "resp_billing_fail",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10048},
|
||||
User: &User{ID: 20048},
|
||||
Account: &Account{ID: 30048},
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 1, billingRepo.calls)
|
||||
require.Equal(t, 0, usageRepo.calls)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
|
||||
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
|
||||
|
||||
@@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct {
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageBillingRepo UsageBillingRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
@@ -295,6 +296,7 @@ type OpenAIGatewayService struct {
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageBillingRepo UsageBillingRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
@@ -312,6 +314,7 @@ func NewOpenAIGatewayService(
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageBillingRepo: usageBillingRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
@@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
|
||||
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
sawDone := false
|
||||
sawTerminalEvent := false
|
||||
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
@@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
if trimmedData == "[DONE]" {
|
||||
sawDone = true
|
||||
}
|
||||
if openAIStreamEventIsTerminal(trimmedData) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
@@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
if sawTerminalEvent {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway",
|
||||
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID,
|
||||
upstreamRequestID,
|
||||
err,
|
||||
ctx.Err(),
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
@@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
)
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
if !clientDisconnected && !sawDone && ctx.Err() == nil {
|
||||
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
|
||||
logger.FromContext(ctx).With(
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("upstream_request_id", upstreamRequestID),
|
||||
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
|
||||
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
@@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||
errorEventSent := false
|
||||
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||
sawTerminalEvent := false
|
||||
sendErrorEvent := func(reason string) {
|
||||
if errorEventSent || clientDisconnected {
|
||||
return
|
||||
@@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
|
||||
}
|
||||
}
|
||||
if !sawTerminalEvent {
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
|
||||
}
|
||||
return resultWithUsage(), nil
|
||||
}
|
||||
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
|
||||
if scanErr == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sawTerminalEvent {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
}
|
||||
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
|
||||
}
|
||||
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
|
||||
return resultWithUsage(), nil, true
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
|
||||
}
|
||||
if errors.Is(scanErr, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
|
||||
@@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
dataBytes := []byte(data)
|
||||
if openAIStreamEventIsTerminal(data) {
|
||||
sawTerminalEvent = true
|
||||
}
|
||||
|
||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
|
||||
@@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
|
||||
return resultWithUsage(), nil
|
||||
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
|
||||
}
|
||||
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||
@@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
|
||||
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
|
||||
return
|
||||
}
|
||||
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
|
||||
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
|
||||
if len(data) < 72 {
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(data, "type").String() != "response.completed" {
|
||||
eventType := gjson.GetBytes(data, "type").String()
|
||||
if eventType != "response.completed" && eventType != "response.done" {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
RequestPayloadHash string
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
accountRateMultiplier := account.BillingRateMultiplier()
|
||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
RequestID: requestID,
|
||||
Model: billingModel,
|
||||
ServiceTier: result.ServiceTier,
|
||||
ReasoningEffort: result.ReasoningEffort,
|
||||
@@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
if shouldBill {
|
||||
postUsageBilling(ctx, &postUsageBillingParams{
|
||||
billingErr := func() error {
|
||||
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
|
||||
Cost: cost,
|
||||
User: user,
|
||||
APIKey: apiKey,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
|
||||
IsSubscriptionBill: isSubscriptionBilling,
|
||||
AccountRateMultiplier: accountRateMultiplier,
|
||||
APIKeyService: input.APIKeyService,
|
||||
}, s.billingDeps())
|
||||
} else {
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
}, s.billingDeps(), s.usageBillingRepo)
|
||||
return err
|
||||
}()
|
||||
|
||||
if billingErr != nil {
|
||||
return billingErr
|
||||
}
|
||||
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||
}
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
|
||||
t.Fatalf("expected incomplete stream error, got %v", err)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||
@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_ = pr.Close()
|
||||
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
|
||||
t.Fatalf("expected missing terminal event error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 2, result.usage.InputTokens)
|
||||
require.Equal(t, 3, result.usage.OutputTokens)
|
||||
require.Equal(t, 1, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
||||
require.Equal(t, 3, usage.InputTokens)
|
||||
require.Equal(t, 5, usage.OutputTokens)
|
||||
require.Equal(t, 2, usage.CacheReadInputTokens)
|
||||
|
||||
// done 事件同样可能携带最终 usage
|
||||
svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
|
||||
require.Equal(t, 13, usage.InputTokens)
|
||||
require.Equal(t, 15, usage.OutputTokens)
|
||||
require.Equal(t, 4, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
|
||||
|
||||
@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
cfg,
|
||||
nil,
|
||||
nil,
|
||||
|
||||
110
backend/internal/service/usage_billing.go
Normal file
110
backend/internal/service/usage_billing.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")
|
||||
var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
|
||||
|
||||
// UsageBillingCommand describes one billable request that must be applied at most once.
|
||||
type UsageBillingCommand struct {
|
||||
RequestID string
|
||||
APIKeyID int64
|
||||
RequestFingerprint string
|
||||
RequestPayloadHash string
|
||||
|
||||
UserID int64
|
||||
AccountID int64
|
||||
SubscriptionID *int64
|
||||
AccountType string
|
||||
Model string
|
||||
ServiceTier string
|
||||
ReasoningEffort string
|
||||
BillingType int8
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
ImageCount int
|
||||
MediaType string
|
||||
|
||||
BalanceCost float64
|
||||
SubscriptionCost float64
|
||||
APIKeyQuotaCost float64
|
||||
APIKeyRateLimitCost float64
|
||||
AccountQuotaCost float64
|
||||
}
|
||||
|
||||
func (c *UsageBillingCommand) Normalize() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.RequestID = strings.TrimSpace(c.RequestID)
|
||||
if strings.TrimSpace(c.RequestFingerprint) == "" {
|
||||
c.RequestFingerprint = buildUsageBillingFingerprint(c)
|
||||
}
|
||||
}
|
||||
|
||||
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
raw := fmt.Sprintf(
|
||||
"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
|
||||
c.UserID,
|
||||
c.AccountID,
|
||||
c.APIKeyID,
|
||||
strings.TrimSpace(c.AccountType),
|
||||
strings.TrimSpace(c.Model),
|
||||
strings.TrimSpace(c.ServiceTier),
|
||||
strings.TrimSpace(c.ReasoningEffort),
|
||||
c.BillingType,
|
||||
c.InputTokens,
|
||||
c.OutputTokens,
|
||||
c.CacheCreationTokens,
|
||||
c.CacheReadTokens,
|
||||
c.ImageCount,
|
||||
strings.TrimSpace(c.MediaType),
|
||||
valueOrZero(c.SubscriptionID),
|
||||
c.BalanceCost,
|
||||
c.SubscriptionCost,
|
||||
c.APIKeyQuotaCost,
|
||||
c.APIKeyRateLimitCost,
|
||||
c.AccountQuotaCost,
|
||||
)
|
||||
if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
|
||||
raw += "|" + payloadHash
|
||||
}
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func HashUsageRequestPayload(payload []byte) string {
|
||||
if len(payload) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(payload)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func valueOrZero(v *int64) int64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
type UsageBillingApplyResult struct {
|
||||
Applied bool
|
||||
APIKeyQuotaExhausted bool
|
||||
}
|
||||
|
||||
type UsageBillingRepository interface {
|
||||
Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
|
||||
}
|
||||
@@ -56,7 +56,8 @@ type cleanupRepoStub struct {
|
||||
}
|
||||
|
||||
type dashboardRepoStub struct {
|
||||
recomputeErr error
|
||||
recomputeErr error
|
||||
recomputeCalls int
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.recomputeErr
|
||||
}
|
||||
|
||||
@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
|
||||
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
dashboardRepo := &dashboardRepoStub{}
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
|
||||
|
||||
60
backend/internal/service/usage_log_create_result.go
Normal file
60
backend/internal/service/usage_log_create_result.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import "errors"
|
||||
|
||||
type usageLogCreateDisposition int
|
||||
|
||||
const (
|
||||
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
|
||||
usageLogCreateDispositionNotPersisted
|
||||
)
|
||||
|
||||
type UsageLogCreateError struct {
|
||||
err error
|
||||
disposition usageLogCreateDisposition
|
||||
}
|
||||
|
||||
func (e *UsageLogCreateError) Error() string {
|
||||
if e == nil || e.err == nil {
|
||||
return "usage log create error"
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *UsageLogCreateError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.err
|
||||
}
|
||||
|
||||
func MarkUsageLogCreateNotPersisted(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLogCreateError{
|
||||
err: err,
|
||||
disposition: usageLogCreateDispositionNotPersisted,
|
||||
}
|
||||
}
|
||||
|
||||
func IsUsageLogCreateNotPersisted(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var target *UsageLogCreateError
|
||||
if !errors.As(err, &target) {
|
||||
return false
|
||||
}
|
||||
return target.disposition == usageLogCreateDispositionNotPersisted
|
||||
}
|
||||
|
||||
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
|
||||
if inserted {
|
||||
return true
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return !IsUsageLogCreateNotPersisted(err)
|
||||
}
|
||||
Reference in New Issue
Block a user