mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-25 05:38:38 +00:00
fix(openai): account cached tokens for
zhipu_v4 usage
This commit is contained in:
@@ -163,13 +163,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
|||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
|
||||||
if info.ChannelType == constant.ChannelTypeDeepSeek {
|
|
||||||
if usage.PromptCacheHitTokens != 0 {
|
|
||||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyUsagePostProcessing(info, usage, nil)
|
||||||
|
|
||||||
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||||
|
|
||||||
return usage, nil
|
return usage, nil
|
||||||
@@ -233,6 +230,8 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
usageModified = true
|
usageModified = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
|
||||||
|
|
||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case types.RelayFormatOpenAI:
|
case types.RelayFormatOpenAI:
|
||||||
if usageModified {
|
if usageModified {
|
||||||
@@ -631,5 +630,60 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
|||||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||||
}
|
}
|
||||||
|
applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
|
||||||
return &usageResp.Usage, nil
|
return &usageResp.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
|
||||||
|
if info == nil || usage == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch info.ChannelType {
|
||||||
|
case constant.ChannelTypeDeepSeek:
|
||||||
|
if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
|
}
|
||||||
|
case constant.ChannelTypeZhipu_v4:
|
||||||
|
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||||
|
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||||
|
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||||
|
} else if usage.PromptCacheHitTokens > 0 {
|
||||||
|
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCachedTokensFromBody pulls a cached-token count out of a raw JSON
// response body. It probes, in priority order,
// usage.prompt_tokens_details.cached_tokens, usage.cached_tokens, and
// usage.prompt_cache_hit_tokens, returning the first field that is present
// (even if explicitly zero) together with true. It returns (0, false) for an
// empty or unparseable body, or when none of the fields is set.
func extractCachedTokensFromBody(body []byte) (int, bool) {
	if len(body) == 0 {
		return 0, false
	}

	// Pointer fields let us distinguish "absent" from an explicit zero.
	var envelope struct {
		Usage struct {
			PromptTokensDetails struct {
				CachedTokens *int `json:"cached_tokens"`
			} `json:"prompt_tokens_details"`
			CachedTokens         *int `json:"cached_tokens"`
			PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
		} `json:"usage"`
	}
	if json.Unmarshal(body, &envelope) != nil {
		return 0, false
	}

	u := envelope.Usage
	for _, candidate := range []*int{
		u.PromptTokensDetails.CachedTokens,
		u.CachedTokens,
		u.PromptCacheHitTokens,
	} {
		if candidate != nil {
			return *candidate, true
		}
	}
	return 0, false
}
|
||||||
|
|||||||
@@ -261,6 +261,7 @@ var streamSupportedChannels = map[int]bool{
|
|||||||
constant.ChannelTypeXai: true,
|
constant.ChannelTypeXai: true,
|
||||||
constant.ChannelTypeDeepSeek: true,
|
constant.ChannelTypeDeepSeek: true,
|
||||||
constant.ChannelTypeBaiduV2: true,
|
constant.ChannelTypeBaiduV2: true,
|
||||||
|
constant.ChannelTypeZhipu_v4: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||||
|
|||||||
Reference in New Issue
Block a user