Files
new-api/relay/channel/gemini/relay-gemini-native.go
2026-01-25 14:52:18 +08:00

112 lines
3.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gemini
import (
"fmt"
"io"
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// GeminiTextGenerationHandler handles a non-streaming Gemini native response.
//
// It reads the full upstream body, decodes it as a dto.GeminiChatResponse,
// derives token usage from the response's UsageMetadata, and then passes the
// raw, unmodified body to service.IOCopyBytesGracefully (presumably the
// write-back to the client — confirm against that helper).
//
// When the response has no candidates but carries a prompt block reason, the
// reason is stored on the gin context under
// constant.ContextKeyAdminRejectReason.
//
// The info parameter is not used by this handler.
//
// It returns the computed usage, or a NewAPIError with
// ErrorCodeBadResponseBody if the body cannot be read or decoded.
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	// Read the whole response body; it is both parsed and forwarded below.
	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if common.DebugEnabled {
		println(string(body))
	}

	// Parse into the Gemini native response format.
	var parsed dto.GeminiChatResponse
	if decodeErr := common.Unmarshal(body, &parsed); decodeErr != nil {
		return nil, types.NewOpenAIError(decodeErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	// No candidates together with a block reason: record the reason on the context.
	if feedback := parsed.PromptFeedback; len(parsed.Candidates) == 0 && feedback != nil && feedback.BlockReason != nil {
		rejectReason := fmt.Sprintf("gemini_block_reason=%s", *feedback.BlockReason)
		common.SetContextKey(c, constant.ContextKeyAdminRejectReason, rejectReason)
	}

	// Compute usage (based on UsageMetadata). Thought tokens are counted as
	// completion tokens and are also reported separately as reasoning tokens.
	meta := parsed.UsageMetadata
	var usage dto.Usage
	usage.PromptTokens = meta.PromptTokenCount
	usage.CompletionTokens = meta.CandidatesTokenCount + meta.ThoughtsTokenCount
	usage.TotalTokens = meta.TotalTokenCount
	usage.CompletionTokenDetails.ReasoningTokens = meta.ThoughtsTokenCount

	// Only the AUDIO and TEXT modalities are broken out; other entries are ignored.
	for _, item := range meta.PromptTokensDetails {
		switch item.Modality {
		case "AUDIO":
			usage.PromptTokensDetails.AudioTokens = item.TokenCount
		case "TEXT":
			usage.PromptTokensDetails.TextTokens = item.TokenCount
		}
	}

	service.IOCopyBytesGracefully(c, resp, body)
	return &usage, nil
}
// NativeGeminiEmbeddingHandler handles a Gemini native embedding response.
//
// It reads the full upstream body, obtains usage from
// service.ResponseText2Usage (called with an empty response text, the upstream
// model name and the estimated prompt token count from info), and checks that
// the body decodes as the expected embedding response type: the batch form
// when info.IsGeminiBatchEmbedding is set, the single form otherwise. The
// decoded value is used only for this check and is then discarded; the raw
// body is passed to service.IOCopyBytesGracefully unchanged.
//
// It returns the usage, or a NewAPIError with ErrorCodeBadResponseBody if the
// body cannot be read or decoded.
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	if common.DebugEnabled {
		println(string(body))
	}

	// Usage is computed before decoding, so it is evaluated even when the
	// body later turns out to be malformed.
	usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())

	// Decode into whichever shape the request type implies; only the decode
	// outcome matters here.
	var decodeErr error
	if info.IsGeminiBatchEmbedding {
		var batch dto.GeminiBatchEmbeddingResponse
		decodeErr = common.Unmarshal(body, &batch)
	} else {
		var single dto.GeminiEmbeddingResponse
		decodeErr = common.Unmarshal(body, &single)
	}
	if decodeErr != nil {
		return nil, types.NewOpenAIError(decodeErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}

	service.IOCopyBytesGracefully(c, resp, body)
	return usage, nil
}
// GeminiTextGenerationStreamHandler handles a streaming Gemini native response.
//
// It sets the event-stream headers on the gin context and then delegates to
// geminiStreamHandler with a callback that writes each data string through
// helper.StringData without modification. Every successful write increments
// info.SendResponseCount. A failed write is logged and the callback returns
// false (presumably telling geminiStreamHandler to stop — confirm there).
//
// The usage and error returned are whatever geminiStreamHandler returns.
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	helper.SetEventStreamHeaders(c)

	// The parsed chunk is not needed here: the raw data string is forwarded as is.
	forward := func(data string, _ *dto.GeminiChatResponse) bool {
		if writeErr := helper.StringData(c, data); writeErr != nil {
			logger.LogError(c, "failed to write stream data: "+writeErr.Error())
			return false
		}
		info.SendResponseCount++
		return true
	}

	return geminiStreamHandler(c, info, resp, forward)
}