mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-29 23:08:38 +00:00
fix: unify usage mapping and include toolUsePromptTokenCount in input tokens
This commit is contained in:
@@ -453,12 +453,14 @@ type GeminiChatResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GeminiUsageMetadata struct {
|
type GeminiUsageMetadata struct {
|
||||||
PromptTokenCount int `json:"promptTokenCount"`
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
ToolUsePromptTokenCount int `json:"toolUsePromptTokenCount"`
|
||||||
TotalTokenCount int `json:"totalTokenCount"`
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
CachedContentTokenCount int `json:"cachedContentTokenCount"`
|
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
CachedContentTokenCount int `json:"cachedContentTokenCount"`
|
||||||
|
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||||
|
ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiPromptTokensDetails struct {
|
type GeminiPromptTokensDetails struct {
|
||||||
|
|||||||
@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 计算使用量(基于 UsageMetadata)
|
// 计算使用量(基于 UsageMetadata)
|
||||||
usage := dto.Usage{
|
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
|
||||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
|
||||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
|
||||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
|
||||||
|
|
||||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
|
||||||
if detail.Modality == "AUDIO" {
|
|
||||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
|
||||||
} else if detail.Modality == "TEXT" {
|
|
||||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
|
|||||||
@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
|
||||||
|
promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
|
||||||
|
if promptTokens <= 0 && fallbackPromptTokens > 0 {
|
||||||
|
promptTokens = fallbackPromptTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := dto.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
|
||||||
|
TotalTokens: metadata.TotalTokenCount,
|
||||||
|
}
|
||||||
|
usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
|
||||||
|
usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
|
||||||
|
|
||||||
|
for _, detail := range metadata.PromptTokensDetails {
|
||||||
|
if detail.Modality == "AUDIO" {
|
||||||
|
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
|
||||||
|
} else if detail.Modality == "TEXT" {
|
||||||
|
usage.PromptTokensDetails.TextTokens += detail.TokenCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, detail := range metadata.ToolUsePromptTokensDetails {
|
||||||
|
if detail.Modality == "AUDIO" {
|
||||||
|
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
|
||||||
|
} else if detail.Modality == "TEXT" {
|
||||||
|
usage.PromptTokensDetails.TextTokens += detail.TokenCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
|
||||||
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
|
||||||
|
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: helper.GetResponseID(c),
|
Id: helper.GetResponseID(c),
|
||||||
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
|
|
||||||
// 更新使用量统计
|
// 更新使用量统计
|
||||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
*usage = mappedUsage
|
||||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
|
||||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
|
||||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
|
||||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
|
||||||
if detail.Modality == "AUDIO" {
|
|
||||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
|
||||||
} else if detail.Modality == "TEXT" {
|
|
||||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return callback(data, &geminiResponse)
|
return callback(data, &geminiResponse)
|
||||||
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
|
||||||
if usage.TotalTokens > 0 {
|
|
||||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
|
||||||
}
|
|
||||||
|
|
||||||
if usage.CompletionTokens <= 0 {
|
if usage.CompletionTokens <= 0 {
|
||||||
if info.ReceivedResponseCount > 0 {
|
if info.ReceivedResponseCount > 0 {
|
||||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||||
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
|||||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
if len(geminiResponse.Candidates) == 0 {
|
if len(geminiResponse.Candidates) == 0 {
|
||||||
usage := dto.Usage{
|
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
|
||||||
}
|
|
||||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
|
||||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
|
||||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
|
||||||
if detail.Modality == "AUDIO" {
|
|
||||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
|
||||||
} else if detail.Modality == "TEXT" {
|
|
||||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if usage.PromptTokens <= 0 {
|
|
||||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
|
||||||
}
|
|
||||||
|
|
||||||
var newAPIError *types.NewAPIError
|
var newAPIError *types.NewAPIError
|
||||||
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||||
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
|||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||||
fullTextResponse.Model = info.UpstreamModelName
|
fullTextResponse.Model = info.UpstreamModelName
|
||||||
usage := dto.Usage{
|
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
|
||||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
|
||||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
|
||||||
}
|
|
||||||
|
|
||||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
|
||||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
|
||||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
|
||||||
|
|
||||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
|
||||||
if detail.Modality == "AUDIO" {
|
|
||||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
|
||||||
} else if detail.Modality == "TEXT" {
|
|
||||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fullTextResponse.Usage = usage
|
fullTextResponse.Usage = usage
|
||||||
|
|
||||||
|
|||||||
333
relay/channel/gemini/relay_gemini_usage_test.go
Normal file
333
relay/channel/gemini/relay_gemini_usage_test.go
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/constant"
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
|
"github.com/QuantumNous/new-api/types"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
RelayFormat: types.RelayFormatGemini,
|
||||||
|
OriginModelName: "gemini-3-flash-preview",
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
UpstreamModelName: "gemini-3-flash-preview",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := dto.GeminiChatResponse{
|
||||||
|
Candidates: []dto.GeminiChatCandidate{
|
||||||
|
{
|
||||||
|
Content: dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []dto.GeminiPart{
|
||||||
|
{Text: "ok"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: 151,
|
||||||
|
ToolUsePromptTokenCount: 18329,
|
||||||
|
CandidatesTokenCount: 1089,
|
||||||
|
ThoughtsTokenCount: 1120,
|
||||||
|
TotalTokenCount: 20689,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := common.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, newAPIError := GeminiChatHandler(c, info, resp)
|
||||||
|
require.Nil(t, newAPIError)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 18480, usage.PromptTokens)
|
||||||
|
require.Equal(t, 2209, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 20689, usage.TotalTokens)
|
||||||
|
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
oldStreamingTimeout := constant.StreamingTimeout
|
||||||
|
constant.StreamingTimeout = 300
|
||||||
|
t.Cleanup(func() {
|
||||||
|
constant.StreamingTimeout = oldStreamingTimeout
|
||||||
|
})
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "gemini-3-flash-preview",
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
UpstreamModelName: "gemini-3-flash-preview",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := dto.GeminiChatResponse{
|
||||||
|
Candidates: []dto.GeminiChatCandidate{
|
||||||
|
{
|
||||||
|
Content: dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []dto.GeminiPart{
|
||||||
|
{Text: "partial"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: 151,
|
||||||
|
ToolUsePromptTokenCount: 18329,
|
||||||
|
CandidatesTokenCount: 1089,
|
||||||
|
ThoughtsTokenCount: 1120,
|
||||||
|
TotalTokenCount: 20689,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkData, err := common.Marshal(chunk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
||||||
|
resp := &http.Response{
|
||||||
|
Body: io.NopCloser(bytes.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
require.Nil(t, newAPIError)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 18480, usage.PromptTokens)
|
||||||
|
require.Equal(t, 2209, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 20689, usage.TotalTokens)
|
||||||
|
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "gemini-3-flash-preview",
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
UpstreamModelName: "gemini-3-flash-preview",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := dto.GeminiChatResponse{
|
||||||
|
Candidates: []dto.GeminiChatCandidate{
|
||||||
|
{
|
||||||
|
Content: dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []dto.GeminiPart{
|
||||||
|
{Text: "ok"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: 151,
|
||||||
|
ToolUsePromptTokenCount: 18329,
|
||||||
|
CandidatesTokenCount: 1089,
|
||||||
|
ThoughtsTokenCount: 1120,
|
||||||
|
TotalTokenCount: 20689,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := common.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
||||||
|
require.Nil(t, newAPIError)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 18480, usage.PromptTokens)
|
||||||
|
require.Equal(t, 2209, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 20689, usage.TotalTokens)
|
||||||
|
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
RelayFormat: types.RelayFormatGemini,
|
||||||
|
OriginModelName: "gemini-3-flash-preview",
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
UpstreamModelName: "gemini-3-flash-preview",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
info.SetEstimatePromptTokens(20)
|
||||||
|
|
||||||
|
payload := dto.GeminiChatResponse{
|
||||||
|
Candidates: []dto.GeminiChatCandidate{
|
||||||
|
{
|
||||||
|
Content: dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []dto.GeminiPart{
|
||||||
|
{Text: "ok"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: 0,
|
||||||
|
ToolUsePromptTokenCount: 0,
|
||||||
|
CandidatesTokenCount: 90,
|
||||||
|
ThoughtsTokenCount: 10,
|
||||||
|
TotalTokenCount: 110,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := common.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, newAPIError := GeminiChatHandler(c, info, resp)
|
||||||
|
require.Nil(t, newAPIError)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 20, usage.PromptTokens)
|
||||||
|
require.Equal(t, 100, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 110, usage.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||||
|
|
||||||
|
oldStreamingTimeout := constant.StreamingTimeout
|
||||||
|
constant.StreamingTimeout = 300
|
||||||
|
t.Cleanup(func() {
|
||||||
|
constant.StreamingTimeout = oldStreamingTimeout
|
||||||
|
})
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "gemini-3-flash-preview",
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
UpstreamModelName: "gemini-3-flash-preview",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
info.SetEstimatePromptTokens(20)
|
||||||
|
|
||||||
|
chunk := dto.GeminiChatResponse{
|
||||||
|
Candidates: []dto.GeminiChatCandidate{
|
||||||
|
{
|
||||||
|
Content: dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []dto.GeminiPart{
|
||||||
|
{Text: "partial"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: 0,
|
||||||
|
ToolUsePromptTokenCount: 0,
|
||||||
|
CandidatesTokenCount: 90,
|
||||||
|
ThoughtsTokenCount: 10,
|
||||||
|
TotalTokenCount: 110,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkData, err := common.Marshal(chunk)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
||||||
|
resp := &http.Response{
|
||||||
|
Body: io.NopCloser(bytes.NewReader(streamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
require.Nil(t, newAPIError)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 20, usage.PromptTokens)
|
||||||
|
require.Equal(t, 100, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 110, usage.TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
||||||
|
|
||||||
|
info := &relaycommon.RelayInfo{
|
||||||
|
OriginModelName: "gemini-3-flash-preview",
|
||||||
|
ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
UpstreamModelName: "gemini-3-flash-preview",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
info.SetEstimatePromptTokens(20)
|
||||||
|
|
||||||
|
payload := dto.GeminiChatResponse{
|
||||||
|
Candidates: []dto.GeminiChatCandidate{
|
||||||
|
{
|
||||||
|
Content: dto.GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []dto.GeminiPart{
|
||||||
|
{Text: "ok"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UsageMetadata: dto.GeminiUsageMetadata{
|
||||||
|
PromptTokenCount: 0,
|
||||||
|
ToolUsePromptTokenCount: 0,
|
||||||
|
CandidatesTokenCount: 90,
|
||||||
|
ThoughtsTokenCount: 10,
|
||||||
|
TotalTokenCount: 110,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := common.Marshal(payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
||||||
|
require.Nil(t, newAPIError)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 20, usage.PromptTokens)
|
||||||
|
require.Equal(t, 100, usage.CompletionTokens)
|
||||||
|
require.Equal(t, 110, usage.TotalTokens)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user