fix(openai): account for cached tokens in zhipu_v4 usage
RedwindA
2025-10-08 16:52:49 +08:00
parent a610ef48e4
commit f930cdbb51
2 changed files with 61 additions and 6 deletions


@@ -163,13 +163,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
    if !containStreamUsage {
        usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
        usage.CompletionTokens += toolCount * 7
    } else {
        if info.ChannelType == constant.ChannelTypeDeepSeek {
            if usage.PromptCacheHitTokens != 0 {
                usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
            }
        }
    }
    applyUsagePostProcessing(info, usage, nil)
    HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
    return usage, nil
@@ -233,6 +230,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
        usageModified = true
    }

    applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)

    switch info.RelayFormat {
    case types.RelayFormatOpenAI:
        if usageModified {
@@ -631,5 +630,60 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
        usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
        usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
    }

    applyUsagePostProcessing(info, &usageResp.Usage, responseBody)

    return &usageResp.Usage, nil
}
// applyUsagePostProcessing normalizes channel-specific cached-token reporting
// into usage.PromptTokensDetails.CachedTokens after the upstream usage is parsed.
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
    if info == nil || usage == nil {
        return
    }

    switch info.ChannelType {
    case constant.ChannelTypeDeepSeek:
        if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
            usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
        }
    case constant.ChannelTypeZhipu_v4:
        if usage.PromptTokensDetails.CachedTokens == 0 {
            if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
                usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
            } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
                usage.PromptTokensDetails.CachedTokens = cachedTokens
            } else if usage.PromptCacheHitTokens > 0 {
                usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
            }
        }
    }
}
// extractCachedTokensFromBody falls back to parsing the raw response body when the
// decoded usage struct carries no cached-token details, checking the known field
// names in order of specificity.
func extractCachedTokensFromBody(body []byte) (int, bool) {
    if len(body) == 0 {
        return 0, false
    }

    var payload struct {
        Usage struct {
            PromptTokensDetails struct {
                CachedTokens *int `json:"cached_tokens"`
            } `json:"prompt_tokens_details"`
            CachedTokens         *int `json:"cached_tokens"`
            PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
        } `json:"usage"`
    }
    if err := json.Unmarshal(body, &payload); err != nil {
        return 0, false
    }

    if payload.Usage.PromptTokensDetails.CachedTokens != nil {
        return *payload.Usage.PromptTokensDetails.CachedTokens, true
    }
    if payload.Usage.CachedTokens != nil {
        return *payload.Usage.CachedTokens, true
    }
    if payload.Usage.PromptCacheHitTokens != nil {
        return *payload.Usage.PromptCacheHitTokens, true
    }
    return 0, false
}
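
Not part of the commit: a minimal test-style sketch of how the new fallback behaves. The sample zhipu_v4 usage payload is an assumed shape (the real upstream body may differ), and the file, package, and test names are hypothetical.

// openai_usage_test.go (hypothetical) — exercises the fallback body parsing above.
package openai

import "testing"

func TestExtractCachedTokensFromBody(t *testing.T) {
    // Assumed zhipu_v4-style body: cached tokens reported under prompt_tokens_details.
    body := []byte(`{"usage":{"prompt_tokens":120,"completion_tokens":30,"prompt_tokens_details":{"cached_tokens":64}}}`)
    if cached, ok := extractCachedTokensFromBody(body); !ok || cached != 64 {
        t.Fatalf("expected cached_tokens=64, got %d (ok=%v)", cached, ok)
    }

    // Without any cache fields the helper reports ok=false, so
    // applyUsagePostProcessing falls through to PromptCacheHitTokens.
    if _, ok := extractCachedTokensFromBody([]byte(`{"usage":{"prompt_tokens":120}}`)); ok {
        t.Fatal("expected ok=false for a body with no cached token fields")
    }
}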


@@ -261,6 +261,7 @@ var streamSupportedChannels = map[int]bool{
    constant.ChannelTypeXai:      true,
    constant.ChannelTypeDeepSeek: true,
    constant.ChannelTypeBaiduV2:  true,
    constant.ChannelTypeZhipu_v4: true,
}

func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {