mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 02:27:11 +00:00
fix: harden usage billing idempotency and backpressure
This commit is contained in:
@@ -246,16 +246,16 @@ func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.
|
|||||||
select {
|
select {
|
||||||
case r.bestEffortBatchCh <- req:
|
case r.bestEffortBatchCh <- req:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return service.MarkUsageLogCreateDropped(ctx.Err())
|
||||||
default:
|
default:
|
||||||
return errors.New("usage log best-effort queue full")
|
return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full"))
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-req.resultCh:
|
case err := <-req.resultCh:
|
||||||
return err
|
return err
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return service.MarkUsageLogCreateDropped(ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,7 +355,7 @@ func (r *usageLogRepository) createBatched(ctx context.Context, log *service.Usa
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
return false, service.MarkUsageLogCreateNotPersisted(ctx.Err())
|
||||||
default:
|
default:
|
||||||
return r.createSingle(ctx, r.sql, log)
|
return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full"))
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -840,27 +840,39 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
cache_ttl_overridden,
|
cache_ttl_overridden,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
ON CONFLICT (request_id, api_key_id) DO UPDATE
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
SET request_id = usage_logs.request_id
|
RETURNING request_id, api_key_id, id, created_at
|
||||||
RETURNING request_id, api_key_id, id, created_at, (xmax = 0) AS inserted
|
),
|
||||||
|
resolved AS (
|
||||||
|
SELECT
|
||||||
|
input.input_idx,
|
||||||
|
input.request_id,
|
||||||
|
input.api_key_id,
|
||||||
|
COALESCE(inserted.id, existing.id) AS id,
|
||||||
|
COALESCE(inserted.created_at, existing.created_at) AS created_at,
|
||||||
|
(inserted.id IS NOT NULL) AS inserted
|
||||||
|
FROM input
|
||||||
|
LEFT JOIN inserted
|
||||||
|
ON inserted.request_id = input.request_id
|
||||||
|
AND inserted.api_key_id = input.api_key_id
|
||||||
|
LEFT JOIN usage_logs existing
|
||||||
|
ON existing.request_id = input.request_id
|
||||||
|
AND existing.api_key_id = input.api_key_id
|
||||||
)
|
)
|
||||||
SELECT COALESCE(
|
SELECT COALESCE(
|
||||||
json_agg(
|
json_agg(
|
||||||
json_build_object(
|
json_build_object(
|
||||||
'request_id', inserted.request_id,
|
'request_id', resolved.request_id,
|
||||||
'api_key_id', inserted.api_key_id,
|
'api_key_id', resolved.api_key_id,
|
||||||
'id', inserted.id,
|
'id', resolved.id,
|
||||||
'created_at', inserted.created_at,
|
'created_at', resolved.created_at,
|
||||||
'inserted', inserted.inserted
|
'inserted', resolved.inserted
|
||||||
)
|
)
|
||||||
ORDER BY input.input_idx
|
ORDER BY resolved.input_idx
|
||||||
),
|
),
|
||||||
'[]'::json
|
'[]'::json
|
||||||
)
|
)
|
||||||
FROM input
|
FROM resolved
|
||||||
JOIN inserted
|
|
||||||
ON inserted.request_id = input.request_id
|
|
||||||
AND inserted.api_key_id = input.api_key_id
|
|
||||||
`)
|
`)
|
||||||
return query.String(), args
|
return query.String(), args
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -288,6 +288,34 @@ func TestUsageLogRepositoryCreateBestEffort_BatchPathDuplicateRequestID(t *testi
|
|||||||
}, 3*time.Second, 20*time.Millisecond)
|
}, 3*time.Second, 20*time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreateBestEffort_QueueFullReturnsDropped(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.bestEffortBatchCh = make(chan usageLogBestEffortRequest, 1)
|
||||||
|
repo.bestEffortBatchCh <- usageLogBestEffortRequest{}
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-best-effort-full-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-best-effort-full-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-best-effort-full-" + uuid.NewString()})
|
||||||
|
|
||||||
|
err := repo.CreateBestEffort(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateDropped(err))
|
||||||
|
}
|
||||||
|
|
||||||
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *testing.T) {
|
||||||
client := testEntClient(t)
|
client := testEntClient(t)
|
||||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
@@ -317,6 +345,35 @@ func TestUsageLogRepositoryCreate_BatchPathCanceledContextMarksNotPersisted(t *t
|
|||||||
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryCreate_BatchPathQueueFullMarksNotPersisted(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
repo.createBatchCh = make(chan usageLogCreateRequest, 1)
|
||||||
|
repo.createBatchCh <- usageLogCreateRequest{}
|
||||||
|
|
||||||
|
user := mustCreateUser(t, client, &service.User{Email: fmt.Sprintf("usage-create-full-%d@example.com", time.Now().UnixNano())})
|
||||||
|
apiKey := mustCreateApiKey(t, client, &service.APIKey{UserID: user.ID, Key: "sk-usage-create-full-" + uuid.NewString(), Name: "k"})
|
||||||
|
account := mustCreateAccount(t, client, &service.Account{Name: "acc-usage-create-full-" + uuid.NewString()})
|
||||||
|
|
||||||
|
inserted, err := repo.Create(ctx, &service.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
APIKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
RequestID: uuid.NewString(),
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.False(t, inserted)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, service.IsUsageLogCreateNotPersisted(err))
|
||||||
|
}
|
||||||
|
|
||||||
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
func TestUsageLogRepositoryCreate_BatchPathCanceledAfterQueueMarksNotPersisted(t *testing.T) {
|
||||||
client := testEntClient(t)
|
client := testEntClient(t)
|
||||||
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
repo := newUsageLogRepositoryWithSQL(client, integrationDB)
|
||||||
|
|||||||
@@ -3,8 +3,11 @@
|
|||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,3 +42,26 @@ func TestSafeDateFormat(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildUsageLogBatchInsertQuery_UsesConflictDoNothing(t *testing.T) {
|
||||||
|
log := &service.UsageLog{
|
||||||
|
UserID: 1,
|
||||||
|
APIKeyID: 2,
|
||||||
|
AccountID: 3,
|
||||||
|
RequestID: "req-batch-no-update",
|
||||||
|
Model: "gpt-5",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 5,
|
||||||
|
TotalCost: 1.2,
|
||||||
|
ActualCost: 1.2,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
prepared := prepareUsageLogInsert(log)
|
||||||
|
|
||||||
|
query, _ := buildUsageLogBatchInsertQuery([]string{usageLogBatchKey(log.RequestID, log.APIKeyID)}, map[string]usageLogInsertPrepared{
|
||||||
|
usageLogBatchKey(log.RequestID, log.APIKeyID): prepared,
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Contains(t, query, "ON CONFLICT (request_id, api_key_id) DO NOTHING")
|
||||||
|
require.NotContains(t, strings.ToUpper(query), "DO UPDATE")
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -233,6 +235,89 @@ func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T
|
|||||||
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
|
||||||
|
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
|
||||||
|
err := svc.RecordUsage(ctx, &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "upstream-volatile-456",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 506},
|
||||||
|
User: &User{ID: 606},
|
||||||
|
Account: &Account{ID: 706},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 507},
|
||||||
|
User: &User{ID: 607},
|
||||||
|
Account: &Account{ID: 707},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
|
||||||
|
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
|
||||||
|
}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
|
||||||
|
Result: &ForwardResult{
|
||||||
|
RequestID: "gateway_drop_usage_log",
|
||||||
|
Usage: ClaudeUsage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 6,
|
||||||
|
},
|
||||||
|
Model: "claude-sonnet-4",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 508},
|
||||||
|
User: &User{ID: 608},
|
||||||
|
Account: &Account{ID: 708},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, usageRepo.bestEffortCalls)
|
||||||
|
require.Equal(t, 0, usageRepo.createCalls)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
|
||||||
|
|||||||
@@ -6745,9 +6745,6 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
|||||||
}
|
}
|
||||||
|
|
||||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||||
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
|
||||||
return requestID
|
|
||||||
}
|
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
|
||||||
return "client:" + strings.TrimSpace(clientRequestID)
|
return "client:" + strings.TrimSpace(clientRequestID)
|
||||||
@@ -6756,7 +6753,10 @@ func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string)
|
|||||||
return "local:" + strings.TrimSpace(requestID)
|
return "local:" + strings.TrimSpace(requestID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
|
||||||
|
return requestID
|
||||||
|
}
|
||||||
|
return "generated:" + generateRequestID()
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
|
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
|
||||||
@@ -6931,6 +6931,9 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
|||||||
if writer, ok := repo.(usageLogBestEffortWriter); ok {
|
if writer, ok := repo.(usageLogBestEffortWriter); ok {
|
||||||
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
|
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
|
||||||
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
|
||||||
|
if IsUsageLogCreateDropped(err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
|
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
|
||||||
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
|
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,6 +29,31 @@ func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog
|
|||||||
return s.inserted, s.err
|
return s.inserted, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIRecordUsageBestEffortLogRepoStub struct {
|
||||||
|
UsageLogRepository
|
||||||
|
|
||||||
|
bestEffortErr error
|
||||||
|
createErr error
|
||||||
|
bestEffortCalls int
|
||||||
|
createCalls int
|
||||||
|
lastLog *UsageLog
|
||||||
|
lastCtxErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
|
||||||
|
s.bestEffortCalls++
|
||||||
|
s.lastLog = log
|
||||||
|
s.lastCtxErr = ctx.Err()
|
||||||
|
return s.bestEffortErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
|
||||||
|
s.createCalls++
|
||||||
|
s.lastLog = log
|
||||||
|
s.lastCtxErr = ctx.Err()
|
||||||
|
return false, s.createErr
|
||||||
|
}
|
||||||
|
|
||||||
type openAIRecordUsageBillingRepoStub struct {
|
type openAIRecordUsageBillingRepoStub struct {
|
||||||
UsageBillingRepository
|
UsageBillingRepository
|
||||||
|
|
||||||
@@ -543,6 +569,65 @@ func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsage
|
|||||||
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123")
|
||||||
|
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "upstream-openai-volatile-456",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10049},
|
||||||
|
User: &User{ID: 20049},
|
||||||
|
Account: &Account{ID: 30049},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID)
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
|
||||||
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
|
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
|
||||||
|
userRepo := &openAIRecordUsageUserRepoStub{}
|
||||||
|
subRepo := &openAIRecordUsageSubRepoStub{}
|
||||||
|
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
|
||||||
|
|
||||||
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
||||||
|
Result: &OpenAIForwardResult{
|
||||||
|
RequestID: "",
|
||||||
|
Usage: OpenAIUsage{
|
||||||
|
InputTokens: 8,
|
||||||
|
OutputTokens: 4,
|
||||||
|
},
|
||||||
|
Model: "gpt-5.1",
|
||||||
|
Duration: time.Second,
|
||||||
|
},
|
||||||
|
APIKey: &APIKey{ID: 10050},
|
||||||
|
User: &User{ID: 20050},
|
||||||
|
Account: &Account{ID: 30050},
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, billingRepo.lastCmd)
|
||||||
|
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
|
||||||
|
require.NotNil(t, usageRepo.lastLog)
|
||||||
|
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
|
||||||
usageRepo := &openAIRecordUsageLogRepoStub{}
|
usageRepo := &openAIRecordUsageLogRepoStub{}
|
||||||
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ type usageLogCreateDisposition int
|
|||||||
const (
|
const (
|
||||||
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
|
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
|
||||||
usageLogCreateDispositionNotPersisted
|
usageLogCreateDispositionNotPersisted
|
||||||
|
usageLogCreateDispositionDropped
|
||||||
)
|
)
|
||||||
|
|
||||||
type UsageLogCreateError struct {
|
type UsageLogCreateError struct {
|
||||||
@@ -38,6 +39,16 @@ func MarkUsageLogCreateNotPersisted(err error) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MarkUsageLogCreateDropped(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UsageLogCreateError{
|
||||||
|
err: err,
|
||||||
|
disposition: usageLogCreateDispositionDropped,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func IsUsageLogCreateNotPersisted(err error) bool {
|
func IsUsageLogCreateNotPersisted(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
@@ -49,6 +60,17 @@ func IsUsageLogCreateNotPersisted(err error) bool {
|
|||||||
return target.disposition == usageLogCreateDispositionNotPersisted
|
return target.disposition == usageLogCreateDispositionNotPersisted
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsUsageLogCreateDropped(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var target *UsageLogCreateError
|
||||||
|
if !errors.As(err, &target) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return target.disposition == usageLogCreateDispositionDropped
|
||||||
|
}
|
||||||
|
|
||||||
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
|
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
|
||||||
if inserted {
|
if inserted {
|
||||||
return true
|
return true
|
||||||
|
|||||||
Reference in New Issue
Block a user