diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2cae9817..a12121fb 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -475,6 +475,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { AccountID: l.AccountID, RequestID: l.RequestID, Model: l.Model, + ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, GroupID: l.GroupID, SubscriptionID: l.SubscriptionID, diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go index d716bdc4..ea408ecb 100644 --- a/backend/internal/handler/dto/mappers_usage_test.go +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -71,3 +71,29 @@ func TestRequestTypeStringPtrNil(t *testing.T) { t.Parallel() require.Nil(t, requestTypeStringPtr(nil)) } + +func TestUsageLogFromService_IncludesServiceTierForUserAndAdmin(t *testing.T) { + t.Parallel() + + serviceTier := "priority" + log := &service.UsageLog{ + RequestID: "req_3", + Model: "gpt-5.4", + ServiceTier: &serviceTier, + AccountRateMultiplier: f64Ptr(1.5), + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.NotNil(t, userDTO.ServiceTier) + require.Equal(t, serviceTier, *userDTO.ServiceTier) + require.NotNil(t, adminDTO.ServiceTier) + require.Equal(t, serviceTier, *adminDTO.ServiceTier) + require.NotNil(t, adminDTO.AccountRateMultiplier) + require.InDelta(t, 1.5, *adminDTO.AccountRateMultiplier, 1e-12) +} + +func f64Ptr(value float64) *float64 { + return &value +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 1c68f429..7a0a0273 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -315,6 +315,8 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` + // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". + ServiceTier *string `json:"service_tier,omitempty"` // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API). // nil means not provided / not applicable. 
ReasoningEffort *string `json:"reasoning_effort,omitempty"` diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 7fc11b78..c91a68e5 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, cache_ttl_overridden, created_at" // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ @@ -135,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) image_count, image_size, media_type, + service_tier, reasoning_effort, cache_ttl_overridden, created_at @@ -144,7 +145,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -158,6 +159,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) mediaType := nullString(log.MediaType) + serviceTier := nullString(log.ServiceTier) reasoningEffort := nullString(log.ReasoningEffort) var requestIDArg any @@ -198,6 +200,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.ImageCount, imageSize, mediaType, + serviceTier, reasoningEffort, log.CacheTTLOverridden, createdAt, @@ -2505,6 +2508,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e imageCount int imageSize sql.NullString mediaType sql.NullString + serviceTier sql.NullString reasoningEffort sql.NullString cacheTTLOverridden bool createdAt time.Time @@ -2544,6 +2548,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &imageCount, &imageSize, &mediaType, + &serviceTier, &reasoningEffort, &cacheTTLOverridden, &createdAt, @@ -2614,6 +2619,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if mediaType.Valid { log.MediaType = &mediaType.String } + if serviceTier.Valid { + log.ServiceTier = &serviceTier.String + } if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } diff --git 
a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 53fb7227..7d82b4d0 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -71,6 +71,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.ImageCount, sqlmock.AnyArg(), // image_size sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // reasoning_effort log.CacheTTLOverridden, createdAt, @@ -81,12 +82,76 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { require.NoError(t, err) require.True(t, inserted) require.Equal(t, int64(99), log.ID) + require.Nil(t, log.ServiceTier) require.Equal(t, service.RequestTypeWSV2, log.RequestType) require.True(t, log.Stream) require.True(t, log.OpenAIWSMode) require.NoError(t, mock.ExpectationsWereMet()) } +func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 2, 12, 0, 0, 0, time.UTC) + serviceTier := "priority" + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-service-tier", + Model: "gpt-5.4", + ServiceTier: &serviceTier, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). + WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeSync), + false, + false, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + sqlmock.AnyArg(), + log.ImageCount, + sqlmock.AnyArg(), + sqlmock.AnyArg(), + serviceTier, + sqlmock.AnyArg(), + log.CacheTTLOverridden, + createdAt, + ). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.NoError(t, mock.ExpectationsWereMet()) +} + func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { db, mock := newSQLMock(t) repo := &usageLogRepository{sql: db} @@ -280,11 +345,14 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { 0, sql.NullString{}, sql.NullString{}, + sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, false, now, }}) require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "priority", *log.ServiceTier) require.Equal(t, service.RequestTypeWSV2, log.RequestType) require.True(t, log.Stream) require.True(t, log.OpenAIWSMode) @@ -316,13 +384,53 @@ 0, sql.NullString{}, sql.NullString{}, + sql.NullString{Valid: true, String: "flex"}, sql.NullString{}, false, now, }}) require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "flex", *log.ServiceTier) require.Equal(t, service.RequestTypeStream, log.RequestType) require.True(t, log.Stream) require.False(t, log.OpenAIWSMode) }) + + t.Run("service_tier_is_scanned", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(3), + int64(12), + int64(22), + int64(32), + sql.NullString{Valid: true, String: "req-3"}, + "gpt-5.4", + sql.NullInt64{}, + sql.NullInt64{}, + 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + int16(service.RequestTypeSync), + false, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{Valid: true, String: "priority"}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.NotNil(t, log.ServiceTier) + require.Equal(t, "priority", *log.ServiceTier) + }) + } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index aafbbe21..236bd658 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -210,8 +210,9 @@ func TestAPIContracts(t *testing.T) { "sora_video_price_per_request": null, "sora_video_price_per_request_hd": null, "claude_code_only": false, + "allow_messages_dispatch": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index d058c25a..68d7a8f9 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -43,16 +43,19 @@ type BillingCache interface { // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) type ModelPricing struct { - InputPricePerToken float64 // 每token输入价格 (USD) - OutputPricePerToken float64 // 每token输出价格 (USD) - CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) - CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) - CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) - SupportsCacheBreakdown bool // 是否支持详细的缓存分类 - LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 - LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 -
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 + InputPricePerToken float64 // 每token输入价格 (USD) + InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD) + OutputPricePerToken float64 // 每token输出价格 (USD) + OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD) + CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) + CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) + CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD) + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) + SupportsCacheBreakdown bool // 是否支持详细的缓存分类 + LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格 + LongContextInputMultiplier float64 // 长上下文整次会话输入倍率 + LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率 } const ( @@ -61,6 +64,28 @@ const ( openAIGPT54LongContextOutputMultiplier = 1.5 ) +func normalizeBillingServiceTier(serviceTier string) string { + return strings.ToLower(strings.TrimSpace(serviceTier)) +} + +func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool { + if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" { + return false + } + return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0 +} + +func serviceTierCostMultiplier(serviceTier string) float64 { + switch normalizeBillingServiceTier(serviceTier) { + case "priority": + return 2.0 + case "flex": + return 0.5 + default: + return 1.0 + } +} + // UsageTokens 使用的token数量 type UsageTokens struct { InputTokens int @@ -173,30 +198,60 @@ func (s *BillingService) initFallbackPricing() { // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) s.fallbackPrices["gpt-5.1"] = &ModelPricing{ - InputPricePerToken: 1.25e-6, // $1.25 per MTok - OutputPricePerToken: 10e-6, // $10 per MTok - CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok - CacheReadPricePerToken: 0.125e-6, - SupportsCacheBreakdown: false, + InputPricePerToken: 1.25e-6, // $1.25 per MTok + InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok + OutputPricePerToken: 10e-6, // $10 per MTok + OutputPricePerTokenPriority: 20e-6, // $20 per MTok + CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok + CacheReadPricePerToken: 0.125e-6, + CacheReadPricePerTokenPriority: 0.25e-6, + SupportsCacheBreakdown: false, } // OpenAI GPT-5.4(业务指定价格) s.fallbackPrices["gpt-5.4"] = &ModelPricing{ - InputPricePerToken: 2.5e-6, // $2.5 per MTok - OutputPricePerToken: 15e-6, // $15 per MTok - CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok - CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok - SupportsCacheBreakdown: false, - LongContextInputThreshold: openAIGPT54LongContextInputThreshold, - LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, - LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, + InputPricePerToken: 2.5e-6, // $2.5 per MTok + InputPricePerTokenPriority: 5e-6, // $5 per MTok + OutputPricePerToken: 15e-6, // $15 per MTok + OutputPricePerTokenPriority: 30e-6, // $30 per MTok + CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok + CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok + CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok + SupportsCacheBreakdown: false, + LongContextInputThreshold: openAIGPT54LongContextInputThreshold, + LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, + LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, + } + // OpenAI GPT-5.2(本地兜底) + 
s.fallbackPrices["gpt-5.2"] = &ModelPricing{ + InputPricePerToken: 1.75e-6, + InputPricePerTokenPriority: 3.5e-6, + OutputPricePerToken: 14e-6, + OutputPricePerTokenPriority: 28e-6, + CacheCreationPricePerToken: 1.75e-6, + CacheReadPricePerToken: 0.175e-6, + CacheReadPricePerTokenPriority: 0.35e-6, + SupportsCacheBreakdown: false, } // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ - InputPricePerToken: 1.5e-6, // $1.5 per MTok - OutputPricePerToken: 12e-6, // $12 per MTok - CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok - CacheReadPricePerToken: 0.15e-6, - SupportsCacheBreakdown: false, + InputPricePerToken: 1.5e-6, // $1.5 per MTok + InputPricePerTokenPriority: 3e-6, // $3 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + OutputPricePerTokenPriority: 24e-6, // $24 per MTok + CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok + CacheReadPricePerToken: 0.15e-6, + CacheReadPricePerTokenPriority: 0.3e-6, + SupportsCacheBreakdown: false, + } + s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{ + InputPricePerToken: 1.75e-6, + InputPricePerTokenPriority: 3.5e-6, + OutputPricePerToken: 14e-6, + OutputPricePerTokenPriority: 28e-6, + CacheCreationPricePerToken: 1.75e-6, + CacheReadPricePerToken: 0.175e-6, + CacheReadPricePerTokenPriority: 0.35e-6, + SupportsCacheBreakdown: false, } s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } @@ -241,6 +296,10 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { switch normalized { case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] + case "gpt-5.2": + return s.fallbackPrices["gpt-5.2"] + case "gpt-5.2-codex": + return s.fallbackPrices["gpt-5.2-codex"] case "gpt-5.3-codex": return s.fallbackPrices["gpt-5.3-codex"] case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": @@ -269,16 +328,19 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr enableBreakdown := price1h > 0 && price1h > price5m return s.applyModelSpecificPricingPolicy(model, &ModelPricing{ - InputPricePerToken: litellmPricing.InputCostPerToken, - OutputPricePerToken: litellmPricing.OutputCostPerToken, - CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, - CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - CacheCreation5mPrice: price5m, - CacheCreation1hPrice: price1h, - SupportsCacheBreakdown: enableBreakdown, - LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, - LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, - LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, + InputPricePerToken: litellmPricing.InputCostPerToken, + InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority, + OutputPricePerToken: litellmPricing.OutputCostPerToken, + OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority, + CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, + CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, + CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, + LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold, + LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier, + LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier, 
}), nil } } @@ -295,6 +357,10 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { // CalculateCost 计算使用费用 func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) { + return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "") +} + +func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) { pricing, err := s.GetModelPricing(model) if err != nil { return nil, err @@ -303,6 +369,21 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul breakdown := &CostBreakdown{} inputPricePerToken := pricing.InputPricePerToken outputPricePerToken := pricing.OutputPricePerToken + cacheReadPricePerToken := pricing.CacheReadPricePerToken + tierMultiplier := 1.0 + if usePriorityServiceTierPricing(serviceTier, pricing) { + if pricing.InputPricePerTokenPriority > 0 { + inputPricePerToken = pricing.InputPricePerTokenPriority + } + if pricing.OutputPricePerTokenPriority > 0 { + outputPricePerToken = pricing.OutputPricePerTokenPriority + } + if pricing.CacheReadPricePerTokenPriority > 0 { + cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority + } + } else { + tierMultiplier = serviceTierCostMultiplier(serviceTier) + } if s.shouldApplySessionLongContextPricing(tokens, pricing) { inputPricePerToken *= pricing.LongContextInputMultiplier outputPricePerToken *= pricing.LongContextOutputMultiplier @@ -329,7 +410,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken } - breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken + breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken + + if tierMultiplier != 1.0 { + breakdown.InputCost *= tierMultiplier + breakdown.OutputCost *= tierMultiplier + breakdown.CacheCreationCost *= tierMultiplier + breakdown.CacheReadCost *= tierMultiplier + } // 计算总费用 breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost + diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 0ba52e56..45bbdcee 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -522,3 +522,189 @@ func TestCalculateCost_LargeTokenCount(t *testing.T) { require.False(t, math.IsNaN(cost.TotalCost)) require.False(t, math.IsInf(cost.TotalCost, 0)) } + +func TestServiceTierCostMultiplier(t *testing.T) { + require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12) + require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12) + require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12) + require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12) + require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12) +} + +func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + 
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0) + require.NoError(t, err) + + flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10) +} + +func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8} + + baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) { + pricingSvc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.4": { + InputCostPerToken: 2.5e-6, + InputCostPerTokenPriority: 5e-6, + OutputCostPerToken: 15e-6, + OutputCostPerTokenPriority: 30e-6, + CacheCreationInputTokenCost: 2.5e-6, + CacheReadInputTokenCost: 0.25e-6, + CacheReadInputTokenCostPriority: 0.5e-6, + LongContextInputTokenThreshold: 272000, + LongContextInputCostMultiplier: 2.0, + LongContextOutputCostMultiplier: 1.5, + }, + }, + } + svc := NewBillingService(&config.Config{}, pricingSvc) + + pricing, err := svc.GetModelPricing("gpt-5.4") + require.NoError(t, err) + require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.Equal(t, 272000, pricing.LongContextInputThreshold) + require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) +} + +func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) { + svc := newTestBillingService() + + gpt52, err := svc.GetModelPricing("gpt-5.2") + require.NoError(t, err) + require.NotNil(t, gpt52) + require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12) + require.InDelta(t, 
3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12) + + gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex") + require.NoError(t, err) + require.NotNil(t, gpt52Codex) + require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12) +} + +func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) { + svc := NewBillingService(&config.Config{}, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "custom-no-priority": { + InputCostPerToken: 1e-6, + OutputCostPerToken: 2e-6, + CacheCreationInputTokenCost: 0.5e-6, + CacheReadInputTokenCost: 0.25e-6, + }, + }, + }) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20} + + baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0) + require.NoError(t, err) + + priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority") + require.NoError(t, err) + + require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10) + require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10) + require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10) + require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10) + require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10) +} + +func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) { + svc := newTestBillingService() + + gpt52, err := svc.GetModelPricing("gpt-5.2") + require.NoError(t, err) + require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12) + require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12) + + gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex") + require.NoError(t, err) + require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12) + require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12) + require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12) +} + +func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) { + svc := NewBillingService(&config.Config{}, &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "dynamic-tier-model": { + InputCostPerToken: 1e-6, + InputCostPerTokenPriority: 2e-6, + OutputCostPerToken: 3e-6, + OutputCostPerTokenPriority: 6e-6, + CacheCreationInputTokenCost: 4e-6, + CacheCreationInputTokenCostAbove1hr: 5e-6, + CacheReadInputTokenCost: 7e-7, + CacheReadInputTokenCostPriority: 8e-7, + LongContextInputTokenThreshold: 999, + LongContextInputCostMultiplier: 1.5, + LongContextOutputCostMultiplier: 1.25, + }, + }, + }) + + pricing, err := svc.GetModelPricing("dynamic-tier-model") + require.NoError(t, err) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12) + require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12) + require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12) + require.True(t, pricing.SupportsCacheBreakdown) + require.InDelta(t, 7e-7, 
pricing.CacheReadPricePerToken, 1e-12) + require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12) + require.Equal(t, 999, pricing.LongContextInputThreshold) + require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) + require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 6fd94a3e..9529462e 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -334,3 +334,225 @@ func TestOpenAIGatewayServiceRecordUsage_ClampsActualInputTokensToZero(t *testin require.NotNil(t, usageRepo.lastLog) require.Equal(t, 0, usageRepo.lastLog.InputTokens) } + +func TestOpenAIGatewayServiceRecordUsage_Gpt54LongContextBillsWholeSession(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_gpt54_long_context", + Usage: OpenAIUsage{ + InputTokens: 300000, + OutputTokens: 2000, + }, + Model: "gpt-5.4-2026-03-05", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1014}, + User: &User{ID: 2014}, + Account: &Account{ID: 3014}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + + expectedInput := 300000 * 2.5e-6 * 2.0 + expectedOutput := 2000 * 15e-6 * 1.5 + require.InDelta(t, expectedInput, usageRepo.lastLog.InputCost, 1e-10) + require.InDelta(t, expectedOutput, usageRepo.lastLog.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, usageRepo.lastLog.TotalCost, 1e-10) + require.InDelta(t, (expectedInput+expectedOutput)*1.1, usageRepo.lastLog.ActualCost, 1e-10) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_ServiceTierPriorityUsesFastPricing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "priority" + usage := OpenAIUsage{InputTokens: 100, OutputTokens: 50} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_service_tier_priority", + ServiceTier: &serviceTier, + Usage: usage, + Model: "gpt-5.4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1015}, + User: &User{ID: 2015}, + Account: &Account{ID: 3015}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.ServiceTier) + require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) + + baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 100, OutputTokens: 50}, 1.0) + require.NoError(t, calcErr) + require.InDelta(t, baseCost.TotalCost*2, usageRepo.lastLog.TotalCost, 1e-10) +} + +func TestOpenAIGatewayServiceRecordUsage_ServiceTierFlexHalvesCost(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "flex" + usage := 
OpenAIUsage{InputTokens: 100, OutputTokens: 50, CacheReadInputTokens: 20} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_service_tier_flex", + ServiceTier: &serviceTier, + Usage: usage, + Model: "gpt-5.4", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1016}, + User: &User{ID: 2016}, + Account: &Account{ID: 3016}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + + baseCost, calcErr := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{InputTokens: 80, OutputTokens: 50, CacheReadTokens: 20}, 1.0) + require.NoError(t, calcErr) + require.InDelta(t, baseCost.TotalCost*0.5, usageRepo.lastLog.TotalCost, 1e-10) +} + +func TestNormalizeOpenAIServiceTier(t *testing.T) { + t.Run("fast maps to priority", func(t *testing.T) { + got := normalizeOpenAIServiceTier(" fast ") + require.NotNil(t, got) + require.Equal(t, "priority", *got) + }) + + t.Run("default ignored", func(t *testing.T) { + require.Nil(t, normalizeOpenAIServiceTier("default")) + }) + + t.Run("invalid ignored", func(t *testing.T) { + require.Nil(t, normalizeOpenAIServiceTier("turbo")) + }) +} + +func TestExtractOpenAIServiceTier(t *testing.T) { + require.Equal(t, "priority", *extractOpenAIServiceTier(map[string]any{"service_tier": "fast"})) + require.Equal(t, "flex", *extractOpenAIServiceTier(map[string]any{"service_tier": "flex"})) + require.Nil(t, extractOpenAIServiceTier(map[string]any{"service_tier": 1})) + require.Nil(t, extractOpenAIServiceTier(nil)) +} + +func TestExtractOpenAIServiceTierFromBody(t *testing.T) { + require.Equal(t, "priority", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"fast"}`))) + require.Equal(t, "flex", *extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"flex"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody([]byte(`{"service_tier":"default"}`))) + require.Nil(t, extractOpenAIServiceTierFromBody(nil)) +} + +func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + serviceTier := "priority" + reasoning := "high" + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_billing_model_override", + BillingModel: "gpt-5.1-codex", + Model: "gpt-5.1", + ServiceTier: &serviceTier, + ReasoningEffort: &reasoning, + Usage: OpenAIUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Duration: 2 * time.Second, + FirstTokenMs: func() *int { v := 120; return &v }(), + }, + APIKey: &APIKey{ID: 10, GroupID: i64p(11), Group: &Group{ID: 11, RateMultiplier: 1.2}}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + UserAgent: "codex-cli/1.0", + IPAddress: "127.0.0.1", + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.ServiceTier) + require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) + require.NotNil(t, usageRepo.lastLog.ReasoningEffort) + require.Equal(t, reasoning, *usageRepo.lastLog.ReasoningEffort) + require.NotNil(t, usageRepo.lastLog.UserAgent) + require.Equal(t, "codex-cli/1.0", *usageRepo.lastLog.UserAgent) + require.NotNil(t, usageRepo.lastLog.IPAddress) + require.Equal(t, "127.0.0.1", *usageRepo.lastLog.IPAddress) + 
require.NotNil(t, usageRepo.lastLog.GroupID) + require.Equal(t, int64(11), *usageRepo.lastLog.GroupID) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + subscription := &UserSubscription{ID: 99} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_subscription_billing", + Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + User: &User{ID: 200}, + Account: &Account{ID: 300}, + Subscription: subscription, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, BillingTypeSubscription, usageRepo.lastLog.BillingType) + require.NotNil(t, usageRepo.lastLog.SubscriptionID) + require.Equal(t, subscription.ID, *usageRepo.lastLog.SubscriptionID) + require.Equal(t, 1, subRepo.incrementCalls) + require.Equal(t, 0, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBillingAfterPersist(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + svc.cfg.RunMode = config.RunModeSimple + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_simple_mode", + Usage: OpenAIUsage{InputTokens: 10, OutputTokens: 5}, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1000}, + User: &User{ID: 2000}, + Account: &Account{ID: 3000}, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 5c8c2710..c84a4d3e 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -213,6 +213,9 @@ type OpenAIForwardResult struct { // This is set by the Anthropic Messages conversion path where // the mapped upstream model differs from the client-facing model. BillingModel string + // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". + // Nil means the request did not specify a recognized tier. + ServiceTier *string // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. // Stored for usage records display; nil means not provided / not applicable. 
ReasoningEffort *string @@ -2036,11 +2039,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) + serviceTier := extractOpenAIServiceTier(reqBody) return &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, + ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, OpenAIWSMode: false, @@ -2195,6 +2200,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: reqModel, + ServiceTier: extractOpenAIServiceTierFromBody(body), ReasoningEffort: reasoningEffort, Stream: reqStream, OpenAIWSMode: false, @@ -3628,7 +3634,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if result.BillingModel != "" { billingModel = result.BillingModel } - cost, err := s.billingService.CalculateCost(billingModel, tokens, multiplier) + serviceTier := "" + if result.ServiceTier != nil { + serviceTier = strings.TrimSpace(*result.ServiceTier) + } + cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) if err != nil { cost = &CostBreakdown{ActualCost: 0} } @@ -3649,6 +3659,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountID: account.ID, RequestID: result.RequestID, Model: billingModel, + ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, @@ -4047,6 +4058,40 @@ func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *s return &value } +func extractOpenAIServiceTier(reqBody map[string]any) *string { + if reqBody == nil { + return nil + } + raw, ok := reqBody["service_tier"].(string) + if !ok { + return nil + } + return normalizeOpenAIServiceTier(raw) +} + +func extractOpenAIServiceTierFromBody(body []byte) *string { + if len(body) == 0 { + return nil + } + return normalizeOpenAIServiceTier(gjson.GetBytes(body, "service_tier").String()) +} + +func normalizeOpenAIServiceTier(raw string) *string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return nil + } + if value == "fast" { + value = "priority" + } + switch value { + case "priority", "flex": + return &value + default: + return nil + } +} + func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { if c != nil { if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 2191d124..6fbd2469 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -671,7 +671,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") - originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"service_tier":"fast","input":[{"type":"text","text":"hi"}]}`) upstreamSSE := strings.Join([]string{ `data: {"type":"response.output_text.delta","delta":"h"}`, @@ -711,6 +711,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *test 
require.GreaterOrEqual(t, time.Since(start), time.Duration(0)) require.NotNil(t, result.FirstTokenMs) require.GreaterOrEqual(t, *result.FirstTokenMs, 0) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "priority", *result.ServiceTier) } func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { @@ -777,7 +779,7 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd c.Request.Header.Set("User-Agent", "curl/8.0") c.Request.Header.Set("X-Test", "keep") - originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"service_tier":"flex","max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) resp := &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, @@ -803,8 +805,11 @@ func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEnd RateMultiplier: f64p(1), } - _, err := svc.Forward(context.Background(), c, account, originalBody) + result, err := svc.Forward(context.Background(), c, account, originalBody) require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) require.NotNil(t, upstream.lastReq) require.Equal(t, originalBody, upstream.lastBody) require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index f9e93f85..526db215 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2302,6 +2302,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( RequestID: responseID, Usage: *usage, Model: originalModel, + ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, OpenAIWSMode: true, @@ -2913,6 +2914,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( RequestID: responseID, Usage: usage, Model: originalModel, + ServiceTier: extractOpenAIServiceTierFromBody(payload), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), Stream: reqStream, OpenAIWSMode: true, diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go index 59e6ecad..c527f2eb 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -399,7 +399,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR }() writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"service_tier":"fast"}`)) cancelWrite() require.NoError(t, err) @@ -424,6 +424,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR require.True(t, result.OpenAIWSMode) require.Equal(t, 2, result.Usage.InputTokens) require.Equal(t, 3, result.Usage.OutputTokens) + require.NotNil(t, result.ServiceTier) + require.Equal(t, 
"priority", *result.ServiceTier) case <-time.After(2 * time.Second): t.Fatal("未收到 passthrough turn 结果回调") } @@ -2593,7 +2595,7 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect require.NoError(t, err) writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false,"service_tier":"flex"}`)) cancelWrite() require.NoError(t, err) // 立即关闭客户端,模拟客户端在 relay 期间断连。 @@ -2611,6 +2613,8 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnect require.Equal(t, "resp_ingress_disconnect", result.RequestID) require.Equal(t, 2, result.Usage.InputTokens) require.Equal(t, 1, result.Usage.OutputTokens) + require.NotNil(t, result.ServiceTier) + require.Equal(t, "flex", *result.ServiceTier) case <-time.After(2 * time.Second): t.Fatal("未收到断连后的 turn 结果回调") } diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index c18c921f..cda2e351 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -77,6 +77,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( return errors.New("token is empty") } requestModel := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) + requestServiceTier := extractOpenAIServiceTierFromBody(firstClientMessage) requestPreviousResponseID := strings.TrimSpace(gjson.GetBytes(firstClientMessage, "previous_response_id").String()) logOpenAIWSV2Passthrough( "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", @@ -178,6 +179,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: turn.Usage.CacheReadInputTokens, }, Model: turn.RequestModel, + ServiceTier: requestServiceTier, Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), @@ -225,6 +227,7 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, }, Model: relayResult.RequestModel, + ServiceTier: requestServiceTier, Stream: true, OpenAIWSMode: true, ResponseHeaders: cloneHeader(handshakeHeaders), diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 897623d6..7ed4e7e4 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -40,13 +40,17 @@ var ( // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { InputCostPerToken float64 `json:"input_cost_per_token"` + InputCostPerTokenPriority float64 `json:"input_cost_per_token_priority"` OutputCostPerToken float64 `json:"output_cost_per_token"` + OutputCostPerTokenPriority float64 `json:"output_cost_per_token_priority"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + CacheReadInputTokenCostPriority float64 `json:"cache_read_input_token_cost_priority"` LongContextInputTokenThreshold int `json:"long_context_input_token_threshold,omitempty"` 
LongContextInputCostMultiplier float64 `json:"long_context_input_cost_multiplier,omitempty"` LongContextOutputCostMultiplier float64 `json:"long_context_output_cost_multiplier,omitempty"` + SupportsServiceTier bool `json:"supports_service_tier"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -62,10 +66,14 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { InputCostPerToken *float64 `json:"input_cost_per_token"` + InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority"` OutputCostPerToken *float64 `json:"output_cost_per_token"` + OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority"` CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority"` + SupportsServiceTier bool `json:"supports_service_tier"` LiteLLMProvider string `json:"litellm_provider"` Mode string `json:"mode"` SupportsPromptCaching bool `json:"supports_prompt_caching"` @@ -324,14 +332,21 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel LiteLLMProvider: entry.LiteLLMProvider, Mode: entry.Mode, SupportsPromptCaching: entry.SupportsPromptCaching, + SupportsServiceTier: entry.SupportsServiceTier, } if entry.InputCostPerToken != nil { pricing.InputCostPerToken = *entry.InputCostPerToken } + if entry.InputCostPerTokenPriority != nil { + pricing.InputCostPerTokenPriority = *entry.InputCostPerTokenPriority + } if entry.OutputCostPerToken != nil { pricing.OutputCostPerToken = *entry.OutputCostPerToken } + if entry.OutputCostPerTokenPriority != nil { + pricing.OutputCostPerTokenPriority = *entry.OutputCostPerTokenPriority + } if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } @@ -341,6 +356,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } + if entry.CacheReadInputTokenCostPriority != nil { + pricing.CacheReadInputTokenCostPriority = *entry.CacheReadInputTokenCostPriority + } if entry.OutputCostPerImage != nil { pricing.OutputCostPerImage = *entry.OutputCostPerImage } diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 6b67c55a..775024fd 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -1,11 +1,40 @@ package service import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" ) +func TestParsePricingData_ParsesPriorityAndServiceTierFields(t *testing.T) { + svc := &PricingService{} + body := []byte(`{ + "gpt-5.4": { + "input_cost_per_token": 0.0000025, + "input_cost_per_token_priority": 0.000005, + "output_cost_per_token": 0.000015, + "output_cost_per_token_priority": 0.00003, + "cache_creation_input_token_cost": 0.0000025, + "cache_read_input_token_cost": 0.00000025, + "cache_read_input_token_cost_priority": 0.0000005, + "supports_service_tier": true, + "supports_prompt_caching": true, + "litellm_provider": "openai", + "mode": "chat" + } + }`) + + data, err := svc.parsePricingData(body) + 
require.NoError(t, err) + pricing := data["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 3e-5, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 5e-7, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} + func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) { sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1} gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9} @@ -68,3 +97,64 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) require.InDelta(t, 2.0, got.LongContextInputCostMultiplier, 1e-12) require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) } + +func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { + raw := map[string]any{ + "gpt-5.4": map[string]any{ + "input_cost_per_token": 2.5e-6, + "input_cost_per_token_priority": 5e-6, + "output_cost_per_token": 15e-6, + "output_cost_per_token_priority": 30e-6, + "cache_read_input_token_cost": 0.25e-6, + "cache_read_input_token_cost_priority": 0.5e-6, + "supports_service_tier": true, + "supports_prompt_caching": true, + "litellm_provider": "openai", + "mode": "chat", + }, + } + body, err := json.Marshal(raw) + require.NoError(t, err) + + svc := &PricingService{} + pricingMap, err := svc.parsePricingData(body) + require.NoError(t, err) + + pricing := pricingMap["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 2.5e-6, pricing.InputCostPerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 15e-6, pricing.OutputCostPerToken, 1e-12) + require.InDelta(t, 30e-6, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.25e-6, pricing.CacheReadInputTokenCost, 1e-12) + require.InDelta(t, 0.5e-6, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} + +func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) { + svc := &PricingService{} + pricingData, err := svc.parsePricingData([]byte(`{ + "gpt-5.4": { + "input_cost_per_token": 0.0000025, + "input_cost_per_token_priority": 0.000005, + "output_cost_per_token": 0.000015, + "output_cost_per_token_priority": 0.00003, + "cache_read_input_token_cost": 0.00000025, + "cache_read_input_token_cost_priority": 0.0000005, + "supports_service_tier": true, + "litellm_provider": "openai", + "mode": "chat" + } + }`)) + require.NoError(t, err) + + pricing := pricingData["gpt-5.4"] + require.NotNil(t, pricing) + require.InDelta(t, 0.0000025, pricing.InputCostPerToken, 1e-12) + require.InDelta(t, 0.000005, pricing.InputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.000015, pricing.OutputCostPerToken, 1e-12) + require.InDelta(t, 0.00003, pricing.OutputCostPerTokenPriority, 1e-12) + require.InDelta(t, 0.00000025, pricing.CacheReadInputTokenCost, 1e-12) + require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12) + require.True(t, pricing.SupportsServiceTier) +} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index c1a95541..a7464956 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,6 +98,8 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". 
+ ServiceTier *string // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API), // e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable. ReasoningEffort *string diff --git a/backend/migrations/070_add_usage_log_service_tier.sql b/backend/migrations/070_add_usage_log_service_tier.sql new file mode 100644 index 00000000..085ec0d6 --- /dev/null +++ b/backend/migrations/070_add_usage_log_service_tier.sql @@ -0,0 +1,5 @@ +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS service_tier VARCHAR(16); + +CREATE INDEX IF NOT EXISTS idx_usage_logs_service_tier_created_at + ON usage_logs (service_tier, created_at); diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 14c434d6..2d8869e2 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -228,6 +228,14 @@ {{ t('admin.usage.outputCost') }} ${{ tooltipData.output_cost.toFixed(6) }} +