diff --git a/controller/channel-test.go b/controller/channel-test.go index 5fc6d749c..5a668c488 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -235,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - err := service.RelayErrorHandler(httpResp, true) + err := service.RelayErrorHandler(c.Request.Context(), httpResp, true) return testResult{ context: c, localErr: err, diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 711cc7a9b..1357e3816 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 59c052f62..dbdc6ee1c 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index a3c6ace6e..8f27fd60b 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr := service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr @@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage imageTokens := usage.PromptTokensDetails.ImageTokens audioTokens := usage.PromptTokensDetails.AudioTokens completionTokens := usage.CompletionTokens + cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens + modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") @@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage modelRatio := relayInfo.PriceData.ModelRatio groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio modelPrice := relayInfo.PriceData.ModelPrice + cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage dImageTokens := decimal.NewFromInt(int64(imageTokens)) dAudioTokens := decimal.NewFromInt(int64(audioTokens)) dCompletionTokens := decimal.NewFromInt(int64(completionTokens)) + dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens)) dCompletionRatio := decimal.NewFromFloat(completionRatio) dCacheRatio := decimal.NewFromFloat(cacheRatio) dImageRatio := decimal.NewFromFloat(imageRatio) dModelRatio := decimal.NewFromFloat(modelRatio) dGroupRatio := decimal.NewFromFloat(groupRatio) dModelPrice := decimal.NewFromFloat(modelPrice) + dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) ratio := dModelRatio.Mul(dGroupRatio) @@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage baseTokens = baseTokens.Sub(dCacheTokens) cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio) } + var dCachedCreationTokensWithRatio decimal.Decimal + if !dCachedCreationTokens.IsZero() { + baseTokens = baseTokens.Sub(dCachedCreationTokens) + dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio) + } // 减去 image tokens var imageTokensWithRatio decimal.Decimal @@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()) } } - promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio) + promptQuota := baseTokens.Add(cachedTokensWithRatio). + Add(imageTokensWithRatio). + Add(dCachedCreationTokensWithRatio) completionQuota := dCompletionTokens.Mul(dCompletionRatio) @@ -395,6 +407,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["image_ratio"] = imageRatio other["image_output"] = imageTokens } + if cachedCreationTokens != 0 { + other["cache_creation_tokens"] = cachedCreationTokens + other["cache_creation_ratio"] = cachedCreationRatio + } if !dWebSearchQuota.IsZero() { if relayInfo.ResponsesUsageInfo != nil { if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 26dcf9719..3d8962bb4 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 460fd2f58..0252d6578 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError @@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } diff --git a/relay/image_handler.go b/relay/image_handler.go index 14a7103c3..e2789ae5e 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index fa3c7bbb4..46d2e25f6 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/responses_handler.go b/relay/responses_handler.go index f5f624c92..d1c5d2158 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/service/error.go b/service/error.go index ef5cbbde6..5c3bddd6e 100644 --- a/service/error.go +++ b/service/error.go @@ -1,12 +1,14 @@ package service import ( + "context" "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "strconv" "strings" @@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude return claudeErr } -func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { +func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) responseBody, err := io.ReadAll(resp.Body) @@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) } else { if common.DebugEnabled { - println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) } newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) }