mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 11:08:37 +00:00
112 lines
3.7 KiB
Go
112 lines
3.7 KiB
Go
package gemini
|
||
|
||
import (
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/logger"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/relay/helper"
|
||
"github.com/QuantumNous/new-api/service"
|
||
"github.com/QuantumNous/new-api/types"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||
defer service.CloseResponseBodyGracefully(resp)
|
||
|
||
// 读取响应体
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
if common.DebugEnabled {
|
||
println(string(responseBody))
|
||
}
|
||
|
||
// 解析为 Gemini 原生响应格式
|
||
var geminiResponse dto.GeminiChatResponse
|
||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
if len(geminiResponse.Candidates) == 0 && geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||
common.SetContextKey(c, constant.ContextKeyAdminRejectReason, fmt.Sprintf("gemini_block_reason=%s", *geminiResponse.PromptFeedback.BlockReason))
|
||
}
|
||
|
||
// 计算使用量(基于 UsageMetadata)
|
||
usage := dto.Usage{
|
||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||
}
|
||
|
||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
|
||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||
if detail.Modality == "AUDIO" {
|
||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||
} else if detail.Modality == "TEXT" {
|
||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||
}
|
||
}
|
||
|
||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||
|
||
return &usage, nil
|
||
}
|
||
|
||
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||
defer service.CloseResponseBodyGracefully(resp)
|
||
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
if common.DebugEnabled {
|
||
println(string(responseBody))
|
||
}
|
||
|
||
usage := service.ResponseText2Usage(c, "", info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||
|
||
if info.IsGeminiBatchEmbedding {
|
||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
} else {
|
||
var geminiResponse dto.GeminiEmbeddingResponse
|
||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
}
|
||
|
||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||
|
||
return usage, nil
|
||
}
|
||
|
||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||
helper.SetEventStreamHeaders(c)
|
||
|
||
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||
err := helper.StringData(c, data)
|
||
if err != nil {
|
||
logger.LogError(c, "failed to write stream data: "+err.Error())
|
||
return false
|
||
}
|
||
info.SendResponseCount++
|
||
return true
|
||
})
|
||
}
|