From f930cdbb51169246ddfdf38d95bf84cf42cc64d4 Mon Sep 17 00:00:00 2001
From: RedwindA
Date: Wed, 8 Oct 2025 16:52:49 +0800
Subject: [PATCH] fix(openai): account cached tokens for zhipu_v4 usage

---
 relay/channel/openai/relay-openai.go | 66 +++++++++++++++++++++++++---
 relay/common/relay_info.go           |  1 +
 2 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index a88b68502..6619fb160 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -163,13 +163,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
 	if !containStreamUsage {
 		usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
 		usage.CompletionTokens += toolCount * 7
-	} else {
-		if info.ChannelType == constant.ChannelTypeDeepSeek {
-			if usage.PromptCacheHitTokens != 0 {
-				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
-			}
-		}
 	}
+
+	applyUsagePostProcessing(info, usage, nil)
+
 	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
 
 	return usage, nil
@@ -233,6 +230,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 		usageModified = true
 	}
 
+	applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
+
 	switch info.RelayFormat {
 	case types.RelayFormatOpenAI:
 		if usageModified {
@@ -631,5 +630,60 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
 		usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
 	}
+	applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
 	return &usageResp.Usage, nil
 }
+
+func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
+	if info == nil || usage == nil {
+		return
+	}
+
+	switch info.ChannelType {
+	case constant.ChannelTypeDeepSeek:
+		if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
+			usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+		}
+	case constant.ChannelTypeZhipu_v4:
+		if usage.PromptTokensDetails.CachedTokens == 0 {
+			if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
+				usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
+			} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
+				usage.PromptTokensDetails.CachedTokens = cachedTokens
+			} else if usage.PromptCacheHitTokens > 0 {
+				usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+			}
+		}
+	}
+}
+
+func extractCachedTokensFromBody(body []byte) (int, bool) {
+	if len(body) == 0 {
+		return 0, false
+	}
+
+	var payload struct {
+		Usage struct {
+			PromptTokensDetails struct {
+				CachedTokens *int `json:"cached_tokens"`
+			} `json:"prompt_tokens_details"`
+			CachedTokens         *int `json:"cached_tokens"`
+			PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
+		} `json:"usage"`
+	}
+
+	if err := json.Unmarshal(body, &payload); err != nil {
+		return 0, false
+	}
+
+	if payload.Usage.PromptTokensDetails.CachedTokens != nil {
+		return *payload.Usage.PromptTokensDetails.CachedTokens, true
+	}
+	if payload.Usage.CachedTokens != nil {
+		return *payload.Usage.CachedTokens, true
+	}
+	if payload.Usage.PromptCacheHitTokens != nil {
+		return *payload.Usage.PromptCacheHitTokens, true
+	}
+	return 0, false
+}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index cc860abd9..3fc1507b2 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -261,6 +261,7 @@ var streamSupportedChannels = map[int]bool{
 	constant.ChannelTypeXai:      true,
 	constant.ChannelTypeDeepSeek: true,
 	constant.ChannelTypeBaiduV2:  true,
+	constant.ChannelTypeZhipu_v4: true,
 }
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {