Files
new-api/relay/channel/cloudflare/relay_cloudflare.go
CaIon f5b409d74f feat: refactor token estimation logic
- Introduced new OpenAI text models in `common/model.go`.
- Added `IsOpenAITextModel` function to check for OpenAI text models.
- Refactored token estimation methods across various channels to use estimated prompt tokens instead of direct prompt token counts.
- Updated related functions and structures to accommodate the new token estimation approach, enhancing overall token management.
2025-12-02 21:34:39 +08:00

148 lines
4.3 KiB
Go

package cloudflare
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// convertCf2CompletionsRequest builds the Cloudflare completion payload from
// an OpenAI-style request.
//
// Only a prompt that is a plain string is carried over; any other prompt shape
// results in an empty Prompt. MaxTokens, Stream and Temperature are copied
// from the incoming request unchanged.
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
	var prompt string
	if s, ok := textRequest.Prompt.(string); ok {
		prompt = s
	}

	cfReq := new(CfRequest)
	cfReq.Prompt = prompt
	cfReq.MaxTokens = textRequest.GetMaxTokens()
	cfReq.Stream = textRequest.Stream
	cfReq.Temperature = textRequest.Temperature
	return cfReq
}
// cfStreamHandler relays a streamed upstream response to the client as
// OpenAI-style chat completion chunks.
//
// Each upstream line has its "data: " prefix and trailing "\r" removed and is
// decoded into a dto.ChatCompletionsStreamResponse. The chunk is then stamped
// with the local response id and info.UpstreamModelName and forwarded with
// helper.ObjectData. Lines shorter than the "data: " prefix are skipped, lines
// that fail to decode are logged and skipped, and the literal "[DONE]" ends
// the loop.
//
// Usage is not read from upstream: it is derived by
// service.ResponseText2Usage from the concatenated delta content and
// info.GetEstimatePromptTokens(). When info.ShouldIncludeUsage is set, a final
// usage chunk is sent before the stream is terminated with helper.Done.
//
// The returned error is always nil; failures while streaming are only logged.
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
	// Release the upstream body on every exit path, not only the normal one.
	defer service.CloseResponseBodyGracefully(resp)

	// NOTE: bufio.Scanner rejects lines longer than bufio.MaxScanTokenSize
	// (64 KiB) unless Buffer is called; such a line ends the loop and is
	// reported through scanner.Err() below.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)

	helper.SetEventStreamHeaders(c)
	id := helper.GetResponseID(c)

	// Collect the generated text so usage can be computed after the stream ends.
	var responseText strings.Builder
	isFirst := true
	for scanner.Scan() {
		data := scanner.Text()
		if len(data) < len("data: ") {
			continue
		}
		data = strings.TrimPrefix(data, "data: ")
		data = strings.TrimSuffix(data, "\r")
		if data == "[DONE]" {
			break
		}

		var response dto.ChatCompletionsStreamResponse
		if err := json.Unmarshal([]byte(data), &response); err != nil {
			logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
			continue
		}

		// Write through the index: ranging by value would hand us a copy of
		// each choice, and the role assignment would never reach the chunk
		// that is sent to the client.
		for i := range response.Choices {
			response.Choices[i].Delta.Role = "assistant"
			responseText.WriteString(response.Choices[i].Delta.GetContentString())
		}
		response.Id = id
		response.Model = info.UpstreamModelName

		err := helper.ObjectData(c, response)
		if isFirst {
			isFirst = false
			info.FirstResponseTime = time.Now()
		}
		if err != nil {
			logger.LogError(c, "error_rendering_stream_response: "+err.Error())
		}
	}
	if err := scanner.Err(); err != nil {
		logger.LogError(c, "error_scanning_stream_response: "+err.Error())
	}

	usage := service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
	if info.ShouldIncludeUsage {
		response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
		if err := helper.ObjectData(c, response); err != nil {
			logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
		}
	}
	helper.Done(c)
	return nil, usage
}
// cfHandler relays a non-streaming upstream response to the client.
//
// The upstream body is read in full and decoded as a dto.TextResponse. Its
// Model is replaced with info.UpstreamModelName, its Id with
// helper.GetResponseID(c), and its Usage with the value computed by
// service.ResponseText2Usage from the concatenated choice contents and
// info.GetEstimatePromptTokens(). The result is re-encoded and written as
// JSON with the upstream status code.
//
// A read, decode or encode failure is returned as a
// types.ErrorCodeBadResponseBody error with a nil usage; in that case nothing
// is written to the client.
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	// Close before inspecting err so a failed read does not leak the
	// upstream body.
	service.CloseResponseBodyGracefully(resp)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
	}

	var response dto.TextResponse
	if err = json.Unmarshal(responseBody, &response); err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
	}
	response.Model = info.UpstreamModelName

	var responseText strings.Builder
	for _, choice := range response.Choices {
		responseText.WriteString(choice.Message.StringContent())
	}
	usage := service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
	response.Usage = *usage
	response.Id = helper.GetResponseID(c)

	jsonResponse, err := json.Marshal(response)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Best effort: the status line is already committed, so a failed write
	// cannot be reported to the client.
	_, _ = c.Writer.Write(jsonResponse)
	return nil, usage
}
// cfSTTHandler relays an upstream audio (speech-to-text) response to the
// client.
//
// The upstream body is read in full and decoded as a CfAudioResponse. Its
// Result.Text is wrapped in a dto.AudioResponse and written as JSON with the
// upstream status code. Usage is computed by service.ResponseText2Usage from
// that text and info.GetEstimatePromptTokens().
//
// A read, decode or encode failure is returned as a
// types.ErrorCodeBadResponseBody error with a nil usage; in that case nothing
// is written to the client.
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	// Close before inspecting err so a failed read does not leak the
	// upstream body.
	service.CloseResponseBodyGracefully(resp)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
	}

	var cfResp CfAudioResponse
	if err = json.Unmarshal(responseBody, &cfResp); err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
	}

	audioResp := &dto.AudioResponse{
		Text: cfResp.Result.Text,
	}
	jsonResponse, err := json.Marshal(audioResp)
	if err != nil {
		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Best effort: the status line is already committed, so a failed write
	// cannot be reported to the client.
	_, _ = c.Writer.Write(jsonResponse)

	usage := service.ResponseText2Usage(c, cfResp.Result.Text, info.UpstreamModelName, info.GetEstimatePromptTokens())
	return nil, usage
}