diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 5a022665..aab66081 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -246,16 +246,16 @@ func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service. select { case r.bestEffortBatchCh <- req: case <-ctx.Done(): - return ctx.Err() + return service.MarkUsageLogCreateDropped(ctx.Err()) default: - return errors.New("usage log best-effort queue full") + return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")) } select { case err := <-req.resultCh: return err case <-ctx.Done(): - return ctx.Err() + return service.MarkUsageLogCreateDropped(ctx.Err()) } } @@ -355,7 +355,7 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa case <-ctx.Done(): return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) default: - return r.createSingle(ctx, r.sql, log) + return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full")) } select { @@ -840,27 +840,39 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage cache_ttl_overridden, created_at FROM input - ON CONFLICT (request_id, api_key_id) DO UPDATE - SET request_id = usage_logs.request_id - RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + ), + resolved AS ( + SELECT + input.input_idx, + input.request_id, + input.api_key_id, + COALESCE(inserted.id, existing.id) AS id, + COALESCE(inserted.created_at, existing.created_at) AS created_at, + (inserted.id IS NOT NULL) AS inserted + FROM input + LEFT JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + LEFT JOIN usage_logs existing + ON existing.request_id = input.request_id + AND existing.api_key_id = input.api_key_id ) SELECT COALESCE( json_agg( json_build_object( - 'request_id', inserted.request_id, - 'api_key_id', inserted.api_key_id, - 'id', inserted.id, - 'created_at', inserted.created_at, - 'inserted', inserted.inserted + 'request_id', resolved.request_id, + 'api_key_id', resolved.api_key_id, + 'id', resolved.id, + 'created_at', resolved.created_at, + 'inserted', resolved.inserted ) - ORDER BY input.input_idx + ORDER BY resolved.input_idx ), '[]'::json ) - FROM input - JOIN inserted - ON inserted.request_id = input.request_id - AND inserted.api_key_id = input.api_key_id + FROM resolved `) return query.String(), args } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 00740878..0383f3bc 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -288,6 +288,34 @@ func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testi }, 3*time.Second, 20*time.Millisecond) } +func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1) + repo.bestEffortBatchCh <- usageLogBestEffortRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()}) + + err := repo.CreateBestEffort(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.Error(t, err) + require.True(t, service.IsUsageLogCreateDropped(err)) +} + func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) { client := testEntClient(t) repo := newUsageLogRepositoryWithSQL(client, integrationDB) @@ -317,6 +345,35 @@ func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *t require.True(t, service.IsUsageLogCreateNotPersisted(err)) } +func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := newUsageLogRepositoryWithSQL(client, integrationDB) + repo.createBatchCh = make(chan usageLogCreateRequest, 1) + repo.createBatchCh <- usageLogCreateRequest{} + + user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())}) + apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"}) + account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()}) + + inserted, err := repo.Create(ctx, &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.NewString(), + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().UTC(), + }) + + require.False(t, inserted) + require.Error(t, err) + require.True(t, service.IsUsageLogCreateNotPersisted(err)) +} + func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) { client := testEntClient(t) repo := newUsageLogRepositoryWithSQL(client, integrationDB) diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go index d0e14ffd..0458902d 100644 --- a/backend/internal/repository/usage_log_repo_unit_test.go +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -3,8 +3,11 @@ package repository import ( + "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) { }) } } + +func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) { + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-batch-no-update", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 5, + TotalCost: 1.2, + ActualCost: 1.2, + CreatedAt: time.Now().UTC(), + } + prepared := prepareUsageLogInsert(log) + + query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{ + usageLogBatchKey(log.RequestID, log.APIKeyID): prepared, + }) + + require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING") + require.NotContains(t, strings.ToUpper(query), "DO UPDATE") +} diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 92e59ac8..475dea6f 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -4,6 +4,8 @@ package service import ( "context" + "errors" + "strings" "testing" "time" @@ -233,6 +235,89 @@ func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID) } +func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123") + ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored") + err := svc.RecordUsage(ctx, &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "upstream-volatile-456", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 506}, + User: &User{ID: 606}, + Account: &Account{ID: 706}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 507}, + User: &User{ID: 607}, + Account: &Account{ID: 707}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + +func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) { + usageRepo := &openAIRecordUsageBestEffortLogRepoStub{ + bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")), + } + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_drop_usage_log", + Usage: ClaudeUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "claude-sonnet-4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 508}, + User: &User{ID: 608}, + Account: &Account{ID: 708}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.bestEffortCalls) + require.Equal(t, 0, usageRepo.createCalls) +} + func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{} billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index f40119f7..a87255b0 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -6745,9 +6745,6 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill } func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { - if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { - return requestID - } if ctx != nil { if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { return "client:" + strings.TrimSpace(clientRequestID) @@ -6756,7 +6753,10 @@ func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) return "local:" + strings.TrimSpace(requestID) } } - return "" + if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { + return requestID + } + return "generated:" + generateRequestID() } func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { @@ -6931,6 +6931,9 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage if writer, ok := repo.(usageLogBestEffortWriter); ok { if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if IsUsageLogCreateDropped(err) { + return + } if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index f05fa5f5..438e9aeb 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -3,6 +3,7 @@ package service import ( "context" "errors" + "strings" "testing" "time" @@ -28,6 +29,31 @@ func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog return s.inserted, s.err } +type openAIRecordUsageBestEffortLogRepoStub struct { + UsageLogRepository + + bestEffortErr error + createErr error + bestEffortCalls int + createCalls int + lastLog *UsageLog + lastCtxErr error +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error { + s.bestEffortCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return s.bestEffortErr +} + +func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.createCalls++ + s.lastLog = log + s.lastCtxErr = ctx.Err() + return false, s.createErr +} + type openAIRecordUsageBillingRepoStub struct { UsageBillingRepository @@ -543,6 +569,65 @@ func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsage require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID) } +func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123") + err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "upstream-openai-volatile-456", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10049}, + User: &User{ID: 20049}, + Account: &Account{ID: 30049}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID) +} + +func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{} + billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10050}, + User: &User{ID: 20050}, + Account: &Account{ID: 30050}, + }) + + require.NoError(t, err) + require.NotNil(t, billingRepo.lastCmd) + require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:")) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID) +} + func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{} billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")} diff --git a/backend/internal/service/usage_log_create_result.go b/backend/internal/service/usage_log_create_result.go index 5e18b44c..1cd84f44 100644 --- a/backend/internal/service/usage_log_create_result.go +++ b/backend/internal/service/usage_log_create_result.go @@ -7,6 +7,7 @@ type usageLogCreateDisposition int const ( usageLogCreateDispositionUnknown usageLogCreateDisposition = iota usageLogCreateDispositionNotPersisted + usageLogCreateDispositionDropped ) type UsageLogCreateError struct { @@ -38,6 +39,16 @@ func MarkUsageLogCreateNotPersisted(err error) error { } } +func MarkUsageLogCreateDropped(err error) error { + if err == nil { + return nil + } + return &UsageLogCreateError{ + err: err, + disposition: usageLogCreateDispositionDropped, + } +} + func IsUsageLogCreateNotPersisted(err error) bool { if err == nil { return false @@ -49,6 +60,17 @@ func IsUsageLogCreateNotPersisted(err error) bool { return target.disposition == usageLogCreateDispositionNotPersisted } +func IsUsageLogCreateDropped(err error) bool { + if err == nil { + return false + } + var target *UsageLogCreateError + if !errors.As(err, &target) { + return false + } + return target.disposition == usageLogCreateDispositionDropped +} + func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool { if inserted { return true