fix: harden usage billing idempotency and backpressure

This commit is contained in:
ius
2026-03-12 18:38:09 +08:00
parent 32d25f76fc
commit 6a685727d0
7 changed files with 311 additions and 21 deletions

View File

@@ -246,16 +246,16 @@ func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.
select {
case r.bestEffortBatchCh <- req:
case <-ctx.Done():
return ctx.Err()
return service.MarkUsageLogCreateDropped(ctx.Err())
default:
return errors.New("usage log best-effort queue full")
return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full"))
}
select {
case err := <-req.resultCh:
return err
case <-ctx.Done():
return ctx.Err()
return service.MarkUsageLogCreateDropped(ctx.Err())
}
}
@@ -355,7 +355,7 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
case <-ctx.Done():
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
default:
return r.createSingle(ctx, r.sql, log)
return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full"))
}
select {
@@ -840,27 +840,39 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_ttl_overridden,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO UPDATE
SET request_id = usage_logs.request_id
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING request_id, api_key_id, id, created_at
),
resolved AS (
SELECT
input.input_idx,
input.request_id,
input.api_key_id,
COALESCE(inserted.id, existing.id) AS id,
COALESCE(inserted.created_at, existing.created_at) AS created_at,
(inserted.id IS NOT NULL) AS inserted
FROM input
LEFT JOIN inserted
ON inserted.request_id = input.request_id
AND inserted.api_key_id = input.api_key_id
LEFT JOIN usage_logs existing
ON existing.request_id = input.request_id
AND existing.api_key_id = input.api_key_id
)
SELECT COALESCE(
json_agg(
json_build_object(
'request_id', inserted.request_id,
'api_key_id', inserted.api_key_id,
'id', inserted.id,
'created_at', inserted.created_at,
'inserted', inserted.inserted
'request_id', resolved.request_id,
'api_key_id', resolved.api_key_id,
'id', resolved.id,
'created_at', resolved.created_at,
'inserted', resolved.inserted
)
ORDER BY input.input_idx
ORDER BY resolved.input_idx
),
'[]'::json
)
FROM input
JOIN inserted
ON inserted.request_id = input.request_id
AND inserted.api_key_id = input.api_key_id
FROM resolved
`)
return query.String(), args
}

View File

@@ -288,6 +288,34 @@ func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testi
}, 3*time.Second, 20*time.Millisecond)
}
// TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped swaps the
// repository's best-effort queue for a single-slot channel that is already
// occupied, then asserts that CreateBestEffort returns an error recognised by
// service.IsUsageLogCreateDropped.
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
	entClient := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(entClient, integrationDB)

	// Occupy the only slot so the next enqueue attempt cannot proceed.
	repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
	repo.bestEffortBatchCh <- usageLogBestEffortRequest{}

	owner := mustCreateUser(t, entClient, &service.User{
		Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano()),
	})
	key := mustCreateApiKey(t, entClient, &service.APIKey{
		UserID: owner.ID,
		Key:    "sk-usage-best-effort-full-" + uuid.NewString(),
		Name:   "k",
	})
	acct := mustCreateAccount(t, entClient, &service.Account{
		Name: "acc-usage-best-effort-full-" + uuid.NewString(),
	})

	entry := &service.UsageLog{
		UserID:       owner.ID,
		APIKeyID:     key.ID,
		AccountID:    acct.ID,
		RequestID:    uuid.NewString(),
		Model:        "claude-3",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    0.5,
		ActualCost:   0.5,
		CreatedAt:    time.Now().UTC(),
	}
	createErr := repo.CreateBestEffort(context.Background(), entry)

	require.Error(t, createErr)
	require.True(t, service.IsUsageLogCreateDropped(createErr))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
@@ -317,6 +345,35 @@ func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *t
require.True(t, service.IsUsageLogCreateNotPersisted(err))
}
// TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted swaps the
// repository's create-batch queue for a single-slot channel that is already
// occupied, then asserts that Create reports inserted == false together with
// an error recognised by service.IsUsageLogCreateNotPersisted.
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
	entClient := testEntClient(t)
	repo := newUsageLogRepositoryWithSQL(entClient, integrationDB)

	// Occupy the only slot so the next enqueue attempt cannot proceed.
	repo.createBatchCh = make(chan usageLogCreateRequest, 1)
	repo.createBatchCh <- usageLogCreateRequest{}

	owner := mustCreateUser(t, entClient, &service.User{
		Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano()),
	})
	key := mustCreateApiKey(t, entClient, &service.APIKey{
		UserID: owner.ID,
		Key:    "sk-usage-create-full-" + uuid.NewString(),
		Name:   "k",
	})
	acct := mustCreateAccount(t, entClient, &service.Account{
		Name: "acc-usage-create-full-" + uuid.NewString(),
	})

	entry := &service.UsageLog{
		UserID:       owner.ID,
		APIKeyID:     key.ID,
		AccountID:    acct.ID,
		RequestID:    uuid.NewString(),
		Model:        "claude-3",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    0.5,
		ActualCost:   0.5,
		CreatedAt:    time.Now().UTC(),
	}
	wasInserted, createErr := repo.Create(context.Background(), entry)

	require.False(t, wasInserted)
	require.Error(t, createErr)
	require.True(t, service.IsUsageLogCreateNotPersisted(createErr))
}
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
client := testEntClient(t)
repo := newUsageLogRepositoryWithSQL(client, integrationDB)

View File

@@ -3,8 +3,11 @@
package repository
import (
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
})
}
}
// TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing builds the batch
// insert statement for a single usage log and asserts that its conflict
// clause is DO NOTHING and that no DO UPDATE appears anywhere in the text.
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
	entry := &service.UsageLog{
		UserID:       1,
		APIKeyID:     2,
		AccountID:    3,
		RequestID:    "req-batch-no-update",
		Model:        "gpt-5",
		InputTokens:  10,
		OutputTokens: 5,
		TotalCost:    1.2,
		ActualCost:   1.2,
		CreatedAt:    time.Now().UTC(),
	}

	preparedEntry := prepareUsageLogInsert(entry)
	batchKey := usageLogBatchKey(entry.RequestID, entry.APIKeyID)
	preparedByKey := map[string]usageLogInsertPrepared{batchKey: preparedEntry}

	sqlText, _ := buildUsageLogBatchInsertQuery([]string{batchKey}, preparedByKey)

	require.Contains(t, sqlText, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
	// Upper-casing makes the negative check case-insensitive.
	require.NotContains(t, strings.ToUpper(sqlText), "DO UPDATE")
}

View File

@@ -4,6 +4,8 @@ package service
import (
"context"
"errors"
"strings"
"testing"
"time"
@@ -233,6 +235,89 @@ func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
// TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID
// supplies a client request id, a local request id and an upstream request id
// at once, and asserts that both the billing command and the usage log carry
// the "client:"-prefixed client id.
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
	const wantRequestID = "client:client-stable-123"

	logRepo := &openAIRecordUsageLogRepoStub{}
	billRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(
		logRepo,
		billRepo,
		&openAIRecordUsageUserRepoStub{},
		&openAIRecordUsageSubRepoStub{},
	)

	reqCtx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
	reqCtx = context.WithValue(reqCtx, ctxkey.RequestID, "req-local-ignored")

	input := &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "upstream-volatile-456",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 506},
		User:    &User{ID: 606},
		Account: &Account{ID: 706},
	}
	recordErr := svc.RecordUsage(reqCtx, input)

	require.NoError(t, recordErr)
	require.NotNil(t, billRepo.lastCmd)
	require.Equal(t, wantRequestID, billRepo.lastCmd.RequestID)
	require.NotNil(t, logRepo.lastLog)
	require.Equal(t, wantRequestID, logRepo.lastLog.RequestID)
}
// TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing calls
// RecordUsage with an empty upstream request id and a bare background context,
// and asserts that the billing command receives a "generated:"-prefixed id
// which the usage log then shares.
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{}
	billRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newGatewayRecordUsageServiceWithBillingRepoForTest(
		logRepo,
		billRepo,
		&openAIRecordUsageUserRepoStub{},
		&openAIRecordUsageSubRepoStub{},
	)

	input := &RecordUsageInput{
		Result: &ForwardResult{
			RequestID: "",
			Usage: ClaudeUsage{
				InputTokens:  10,
				OutputTokens: 6,
			},
			Model:    "claude-sonnet-4",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 507},
		User:    &User{ID: 607},
		Account: &Account{ID: 707},
	}
	recordErr := svc.RecordUsage(context.Background(), input)

	require.NoError(t, recordErr)
	require.NotNil(t, billRepo.lastCmd)
	hasGeneratedPrefix := strings.HasPrefix(billRepo.lastCmd.RequestID, "generated:")
	require.True(t, hasGeneratedPrefix)
	require.NotNil(t, logRepo.lastLog)
	// Billing and the usage log must agree on the same id.
	require.Equal(t, billRepo.lastCmd.RequestID, logRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_drop_usage_log",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 508},
User: &User{ID: 608},
Account: &Account{ID: 708},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.bestEffortCalls)
require.Equal(t, 0, usageRepo.createCalls)
}
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}

View File

@@ -6745,9 +6745,6 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
}
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
return requestID
}
if ctx != nil {
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
return "client:" + strings.TrimSpace(clientRequestID)
@@ -6756,7 +6753,10 @@ func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string)
return "local:" + strings.TrimSpace(requestID)
}
}
return ""
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
return requestID
}
return "generated:" + generateRequestID()
}
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
@@ -6931,6 +6931,9 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
if writer, ok := repo.(usageLogBestEffortWriter); ok {
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
if IsUsageLogCreateDropped(err) {
return
}
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"strings"
"testing"
"time"
@@ -28,6 +29,31 @@ func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog
return s.inserted, s.err
}
// openAIRecordUsageBestEffortLogRepoStub is a test double that embeds
// UsageLogRepository and overrides Create and CreateBestEffort, recording how
// each was invoked and returning caller-configured errors.
type openAIRecordUsageBestEffortLogRepoStub struct {
	UsageLogRepository

	// Errors returned by CreateBestEffort and Create respectively.
	bestEffortErr, createErr error
	// Number of times CreateBestEffort and Create have been called.
	bestEffortCalls, createCalls int
	// Log passed to the most recent call of either method.
	lastLog *UsageLog
	// ctx.Err() observed at the most recent call of either method.
	lastCtxErr error
}
// CreateBestEffort records the call (counter, log, and the context's error
// state at call time) and returns the configured bestEffortErr.
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
	s.lastLog, s.lastCtxErr = log, ctx.Err()
	s.bestEffortCalls++
	return s.bestEffortErr
}
// Create records the call (counter, log, and the context's error state at
// call time). It always reports inserted == false alongside the configured
// createErr.
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
	s.lastLog, s.lastCtxErr = log, ctx.Err()
	s.createCalls++
	return false, s.createErr
}
type openAIRecordUsageBillingRepoStub struct {
UsageBillingRepository
@@ -543,6 +569,65 @@ func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsage
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
}
// TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID
// supplies both a client request id (via context) and an upstream request id,
// and asserts that the billing command and the usage log both carry the
// "client:"-prefixed client id.
func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
	const wantRequestID = "client:openai-client-stable-123"

	logRepo := &openAIRecordUsageLogRepoStub{}
	billRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(
		logRepo,
		billRepo,
		&openAIRecordUsageUserRepoStub{},
		&openAIRecordUsageSubRepoStub{},
		nil,
	)

	reqCtx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123")
	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "upstream-openai-volatile-456",
			Usage: OpenAIUsage{
				InputTokens:  8,
				OutputTokens: 4,
			},
			Model:    "gpt-5.1",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 10049},
		User:    &User{ID: 20049},
		Account: &Account{ID: 30049},
	}
	recordErr := svc.RecordUsage(reqCtx, input)

	require.NoError(t, recordErr)
	require.NotNil(t, billRepo.lastCmd)
	require.Equal(t, wantRequestID, billRepo.lastCmd.RequestID)
	require.NotNil(t, logRepo.lastLog)
	require.Equal(t, wantRequestID, logRepo.lastLog.RequestID)
}
// TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing
// calls RecordUsage with an empty upstream request id and a bare background
// context, and asserts that the billing command receives a
// "generated:"-prefixed id which the usage log then shares.
func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
	logRepo := &openAIRecordUsageLogRepoStub{}
	billRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
	svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(
		logRepo,
		billRepo,
		&openAIRecordUsageUserRepoStub{},
		&openAIRecordUsageSubRepoStub{},
		nil,
	)

	input := &OpenAIRecordUsageInput{
		Result: &OpenAIForwardResult{
			RequestID: "",
			Usage: OpenAIUsage{
				InputTokens:  8,
				OutputTokens: 4,
			},
			Model:    "gpt-5.1",
			Duration: time.Second,
		},
		APIKey:  &APIKey{ID: 10050},
		User:    &User{ID: 20050},
		Account: &Account{ID: 30050},
	}
	recordErr := svc.RecordUsage(context.Background(), input)

	require.NoError(t, recordErr)
	require.NotNil(t, billRepo.lastCmd)
	hasGeneratedPrefix := strings.HasPrefix(billRepo.lastCmd.RequestID, "generated:")
	require.True(t, hasGeneratedPrefix)
	require.NotNil(t, logRepo.lastLog)
	// Billing and the usage log must agree on the same id.
	require.Equal(t, billRepo.lastCmd.RequestID, logRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}

View File

@@ -7,6 +7,7 @@ type usageLogCreateDisposition int
// Dispositions that a UsageLogCreateError can carry.
const (
	// usageLogCreateDispositionUnknown is the zero value of the type.
	usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
	// usageLogCreateDispositionNotPersisted is the disposition that
	// IsUsageLogCreateNotPersisted checks for.
	usageLogCreateDispositionNotPersisted
	// usageLogCreateDispositionDropped is the disposition attached by
	// MarkUsageLogCreateDropped and checked by IsUsageLogCreateDropped.
	usageLogCreateDispositionDropped
)
type UsageLogCreateError struct {
@@ -38,6 +39,16 @@ func MarkUsageLogCreateNotPersisted(err error) error {
}
}
// MarkUsageLogCreateDropped wraps err in a *UsageLogCreateError whose
// disposition is "dropped". A nil err is passed through as nil so callers can
// wrap unconditionally.
func MarkUsageLogCreateDropped(err error) error {
	if err == nil {
		// Return a literal nil rather than a typed nil pointer.
		return nil
	}
	return &UsageLogCreateError{err: err, disposition: usageLogCreateDispositionDropped}
}
func IsUsageLogCreateNotPersisted(err error) bool {
if err == nil {
return false
@@ -49,6 +60,17 @@ func IsUsageLogCreateNotPersisted(err error) bool {
return target.disposition == usageLogCreateDispositionNotPersisted
}
// IsUsageLogCreateDropped reports whether err's chain contains a
// *UsageLogCreateError and the first one found carries the "dropped"
// disposition. It returns false for a nil err.
func IsUsageLogCreateDropped(err error) bool {
	var createErr *UsageLogCreateError
	// errors.As already yields false for a nil err, so no separate nil guard
	// is needed before the disposition check.
	return errors.As(err, &createErr) &&
		createErr.disposition == usageLogCreateDispositionDropped
}
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
if inserted {
return true