mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 00:31:24 +00:00
fix: harden usage billing idempotency and backpressure
This commit is contained in:
@@ -246,16 +246,16 @@ func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.
|
||||
select {
|
||||
case r.bestEffortBatchCh <- req:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return service.MarkUsageLogCreateDropped(ctx.Err())
|
||||
default:
|
||||
return errors.New("usage log best-effort queue full")
|
||||
return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full"))
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-req.resultCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return service.MarkUsageLogCreateDropped(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,7 +355,7 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
|
||||
case <-ctx.Done():
|
||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||
default:
|
||||
return r.createSingle(ctx, r.sql, log)
|
||||
return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full"))
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -840,27 +840,39 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
FROM input
|
||||
ON CONFLICT (request_id, api_key_id) DO UPDATE
|
||||
SET request_id = usage_logs.request_id
|
||||
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING request_id, api_key_id, id, created_at
|
||||
),
|
||||
resolved AS (
|
||||
SELECT
|
||||
input.input_idx,
|
||||
input.request_id,
|
||||
input.api_key_id,
|
||||
COALESCE(inserted.id, existing.id) AS id,
|
||||
COALESCE(inserted.created_at, existing.created_at) AS created_at,
|
||||
(inserted.id IS NOT NULL) AS inserted
|
||||
FROM input
|
||||
LEFT JOIN inserted
|
||||
ON inserted.request_id = input.request_id
|
||||
AND inserted.api_key_id = input.api_key_id
|
||||
LEFT JOIN usage_logs existing
|
||||
ON existing.request_id = input.request_id
|
||||
AND existing.api_key_id = input.api_key_id
|
||||
)
|
||||
SELECT COALESCE(
|
||||
json_agg(
|
||||
json_build_object(
|
||||
'request_id', inserted.request_id,
|
||||
'api_key_id', inserted.api_key_id,
|
||||
'id', inserted.id,
|
||||
'created_at', inserted.created_at,
|
||||
'inserted', inserted.inserted
|
||||
'request_id', resolved.request_id,
|
||||
'api_key_id', resolved.api_key_id,
|
||||
'id', resolved.id,
|
||||
'created_at', resolved.created_at,
|
||||
'inserted', resolved.inserted
|
||||
)
|
||||
ORDER BY input.input_idx
|
||||
ORDER BY resolved.input_idx
|
||||
),
|
||||
'[]'::json
|
||||
)
|
||||
FROM input
|
||||
JOIN inserted
|
||||
ON inserted.request_id = input.request_id
|
||||
AND inserted.api_key_id = input.api_key_id
|
||||
FROM resolved
|
||||
`)
|
||||
return query.String(), args
|
||||
}
|
||||
|
||||
@@ -288,6 +288,34 @@ func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testi
|
||||
}, 3*time.Second, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||
|
||||
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
@@ -317,6 +345,35 @@ func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *t
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||
repo.createBatchCh <- usageLogCreateRequest{}
|
||||
|
||||
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||
|
||||
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.NewString(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
|
||||
require.False(t, inserted)
|
||||
require.Error(t, err)
|
||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||
}
|
||||
|
||||
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||
log := &service.UsageLog{
|
||||
UserID: 1,
|
||||
APIKeyID: 2,
|
||||
AccountID: 3,
|
||||
RequestID: "req-batch-no-update",
|
||||
Model: "gpt-5",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 1.2,
|
||||
ActualCost: 1.2,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
prepared := prepareUsageLogInsert(log)
|
||||
|
||||
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||
})
|
||||
|
||||
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -233,6 +235,89 @@ func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T
|
||||
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
|
||||
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
|
||||
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "upstream-volatile-456",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 506},
|
||||
User: &User{ID: 606},
|
||||
Account: &Account{ID: 706},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 507},
|
||||
User: &User{ID: 607},
|
||||
Account: &Account{ID: 707},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
|
||||
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
|
||||
}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||
Result: &ForwardResult{
|
||||
RequestID: "gateway_drop_usage_log",
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 6,
|
||||
},
|
||||
Model: "claude-sonnet-4",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 508},
|
||||
User: &User{ID: 608},
|
||||
Account: &Account{ID: 708},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, usageRepo.bestEffortCalls)
|
||||
require.Equal(t, 0, usageRepo.createCalls)
|
||||
}
|
||||
|
||||
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||
|
||||
@@ -6745,9 +6745,6 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
||||
}
|
||||
|
||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
if ctx != nil {
|
||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||
return "client:" + strings.TrimSpace(clientRequestID)
|
||||
@@ -6756,7 +6753,10 @@ func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string)
|
||||
return "local:" + strings.TrimSpace(requestID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
return "generated:" + generateRequestID()
|
||||
}
|
||||
|
||||
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
|
||||
@@ -6931,6 +6931,9 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
||||
if writer, ok := repo.(usageLogBestEffortWriter); ok {
|
||||
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||
if IsUsageLogCreateDropped(err) {
|
||||
return
|
||||
}
|
||||
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
|
||||
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -28,6 +29,31 @@ func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog
|
||||
return s.inserted, s.err
|
||||
}
|
||||
|
||||
type openAIRecordUsageBestEffortLogRepoStub struct {
|
||||
UsageLogRepository
|
||||
|
||||
bestEffortErr error
|
||||
createErr error
|
||||
bestEffortCalls int
|
||||
createCalls int
|
||||
lastLog *UsageLog
|
||||
lastCtxErr error
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
|
||||
s.bestEffortCalls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return s.bestEffortErr
|
||||
}
|
||||
|
||||
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||
s.createCalls++
|
||||
s.lastLog = log
|
||||
s.lastCtxErr = ctx.Err()
|
||||
return false, s.createErr
|
||||
}
|
||||
|
||||
type openAIRecordUsageBillingRepoStub struct {
|
||||
UsageBillingRepository
|
||||
|
||||
@@ -543,6 +569,65 @@ func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsage
|
||||
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123")
|
||||
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "upstream-openai-volatile-456",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10049},
|
||||
User: &User{ID: 20049},
|
||||
Account: &Account{ID: 30049},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID)
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||
|
||||
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||
Result: &OpenAIForwardResult{
|
||||
RequestID: "",
|
||||
Usage: OpenAIUsage{
|
||||
InputTokens: 8,
|
||||
OutputTokens: 4,
|
||||
},
|
||||
Model: "gpt-5.1",
|
||||
Duration: time.Second,
|
||||
},
|
||||
APIKey: &APIKey{ID: 10050},
|
||||
User: &User{ID: 20050},
|
||||
Account: &Account{ID: 30050},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, billingRepo.lastCmd)
|
||||
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||
require.NotNil(t, usageRepo.lastLog)
|
||||
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
||||
|
||||
@@ -7,6 +7,7 @@ type usageLogCreateDisposition int
|
||||
const (
|
||||
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
|
||||
usageLogCreateDispositionNotPersisted
|
||||
usageLogCreateDispositionDropped
|
||||
)
|
||||
|
||||
type UsageLogCreateError struct {
|
||||
@@ -38,6 +39,16 @@ func MarkUsageLogCreateNotPersisted(err error) error {
|
||||
}
|
||||
}
|
||||
|
||||
func MarkUsageLogCreateDropped(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageLogCreateError{
|
||||
err: err,
|
||||
disposition: usageLogCreateDispositionDropped,
|
||||
}
|
||||
}
|
||||
|
||||
func IsUsageLogCreateNotPersisted(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -49,6 +60,17 @@ func IsUsageLogCreateNotPersisted(err error) bool {
|
||||
return target.disposition == usageLogCreateDispositionNotPersisted
|
||||
}
|
||||
|
||||
func IsUsageLogCreateDropped(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var target *UsageLogCreateError
|
||||
if !errors.As(err, &target) {
|
||||
return false
|
||||
}
|
||||
return target.disposition == usageLogCreateDispositionDropped
|
||||
}
|
||||
|
||||
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
|
||||
if inserted {
|
||||
return true
|
||||
|
||||
Reference in New Issue
Block a user