diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ac44312eb..a4c6ef605 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -186,7 +186,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re usage.CompletionTokens += toolCount * 7 } - applyUsagePostProcessing(info, usage, nil) + applyUsagePostProcessing(info, usage, common.StringToByteSlice(lastStreamData)) HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage) @@ -596,7 +596,8 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 { usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens } - case constant.ChannelTypeZhipu_v4, constant.ChannelTypeMoonshot: + case constant.ChannelTypeZhipu_v4: + // 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens if usage.PromptTokensDetails.CachedTokens == 0 { if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens @@ -606,6 +607,19 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens } } + case constant.ChannelTypeMoonshot: + // Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens + if usage.PromptTokensDetails.CachedTokens == 0 { + if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 { + usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens + } else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok { + usage.PromptTokensDetails.CachedTokens = cachedTokens + } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok { + usage.PromptTokensDetails.CachedTokens = cachedTokens + } else if usage.PromptCacheHitTokens > 0 { + usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens + } + } } } @@ -639,3 +653,32 @@ func extractCachedTokensFromBody(body []byte) (int, bool) { } return 0, false } + +// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens +// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]} +func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) { + if len(body) == 0 { + return 0, false + } + + var payload struct { + Choices []struct { + Usage struct { + CachedTokens *int `json:"cached_tokens"` + } `json:"usage"` + } `json:"choices"` + } + + if err := common.Unmarshal(body, &payload); err != nil { + return 0, false + } + + // 遍历choices查找cached_tokens + for _, choice := range payload.Choices { + if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 { + return *choice.Usage.CachedTokens, true + } + } + + return 0, false +}