mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-01 03:01:21 +00:00
Compare commits
4 Commits
fix-random
...
v0.9.13
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c4ca9d7c3b | ||
|
|
303feafc3c | ||
|
|
b2de5e229c | ||
|
|
8297723d91 |
@@ -30,10 +30,11 @@ func GetAudioDuration(ctx context.Context, f io.ReadSeeker, ext string) (duratio
|
||||
duration, err = getFLACDuration(f)
|
||||
case ".m4a", ".mp4":
|
||||
duration, err = getM4ADuration(f)
|
||||
case ".ogg", ".oga":
|
||||
case ".ogg", ".oga", ".opus":
|
||||
duration, err = getOGGDuration(f)
|
||||
case ".opus":
|
||||
duration, err = getOpusDuration(f)
|
||||
if err != nil {
|
||||
duration, err = getOpusDuration(f)
|
||||
}
|
||||
case ".aiff", ".aif", ".aifc":
|
||||
duration, err = getAIFFDuration(f)
|
||||
case ".webm":
|
||||
|
||||
@@ -230,10 +230,6 @@ func GetUUID() string {
|
||||
|
||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func init() {
|
||||
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
}
|
||||
|
||||
func GenerateRandomCharsKey(length int) (string, error) {
|
||||
b := make([]byte, length)
|
||||
maxI := big.NewInt(int64(len(keyChars)))
|
||||
|
||||
@@ -84,6 +84,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
logger.LogError(c, fmt.Sprintf("relay error: %s", newAPIError.Error()))
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
@@ -281,7 +282,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("channel error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
|
||||
@@ -142,7 +142,6 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
targetPriority := int64(sortedUniquePriorities[retry])
|
||||
|
||||
// get the priority for the given retry number
|
||||
var shouldSmooth = false
|
||||
var sumWeight = 0
|
||||
var targetChannels []*Channel
|
||||
for _, channelId := range channels {
|
||||
@@ -155,38 +154,34 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
|
||||
}
|
||||
}
|
||||
if sumWeight/len(targetChannels) < 10 {
|
||||
shouldSmooth = true
|
||||
|
||||
if len(targetChannels) == 0 {
|
||||
return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority))
|
||||
}
|
||||
|
||||
// 平滑系数
|
||||
// smoothing factor and adjustment
|
||||
smoothingFactor := 1
|
||||
if shouldSmooth {
|
||||
smoothingAdjustment := 0
|
||||
|
||||
if sumWeight == 0 {
|
||||
// when all channels have weight 0, set sumWeight to the number of channels and set smoothing adjustment to 100
|
||||
// each channel's effective weight = 100
|
||||
sumWeight = len(targetChannels) * 100
|
||||
smoothingAdjustment = 100
|
||||
} else if sumWeight/len(targetChannels) < 10 {
|
||||
// when the average weight is less than 10, set smoothing factor to 100
|
||||
smoothingFactor = 100
|
||||
}
|
||||
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := sumWeight * smoothingFactor
|
||||
|
||||
// totalWeight 小于等于0时,给每个渠道加100的权重,然后进行随机选择
|
||||
if totalWeight <= 0 {
|
||||
if len(targetChannels) > 0 {
|
||||
totalWeight = len(targetChannels) * 100
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
for _, channel := range targetChannels {
|
||||
randomWeight -= 100
|
||||
if randomWeight <= 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, errors.New("no available channels")
|
||||
}
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range targetChannels {
|
||||
randomWeight -= channel.GetWeight() * smoothingFactor
|
||||
randomWeight -= channel.GetWeight()*smoothingFactor + smoothingAdjustment
|
||||
if randomWeight < 0 {
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
@@ -15,9 +15,11 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ai360"
|
||||
"github.com/QuantumNous/new-api/relay/channel/lingyiwanwu"
|
||||
|
||||
//"github.com/QuantumNous/new-api/relay/channel/minimax"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openrouter"
|
||||
"github.com/QuantumNous/new-api/relay/channel/xinference"
|
||||
@@ -352,27 +354,43 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
|
||||
// 获取所有表单字段
|
||||
formData := c.Request.PostForm
|
||||
formData, err2 := common.ParseMultipartFormReusable(c)
|
||||
if err2 != nil {
|
||||
return nil, fmt.Errorf("error parsing multipart form: %w", err2)
|
||||
}
|
||||
|
||||
// 打印类似 curl 命令格式的信息
|
||||
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form 'model=\"%s\"'", request.Model))
|
||||
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
for key, values := range formData.Value {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form '%s=\"%s\"'", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
// 添加文件字段
|
||||
file, header, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
// 从 formData 中获取文件
|
||||
fileHeaders := formData.File["file"]
|
||||
if len(fileHeaders) == 0 {
|
||||
return nil, errors.New("file is required")
|
||||
}
|
||||
|
||||
// 使用 formData 中的第一个文件
|
||||
fileHeader := fileHeaders[0]
|
||||
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--form 'file=@\"%s\"' (size: %d bytes, content-type: %s)",
|
||||
fileHeader.Filename, fileHeader.Size, fileHeader.Header.Get("Content-Type")))
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening audio file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
part, err := writer.CreateFormFile("file", header.Filename)
|
||||
part, err := writer.CreateFormFile("file", fileHeader.Filename)
|
||||
if err != nil {
|
||||
return nil, errors.New("create form file failed")
|
||||
}
|
||||
@@ -383,6 +401,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
logger.LogDebug(c.Request.Context(), fmt.Sprintf("--header 'Content-Type: %s'", writer.FormDataContentType()))
|
||||
return &requestBody, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,8 +31,8 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not supported")
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertAudioRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
@@ -65,16 +65,8 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeChatCompletions {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeCompletions {
|
||||
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
return fmt.Sprintf("%s/v1/images/generations", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -103,7 +95,8 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoRequest(c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -118,21 +111,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
usage, err = siliconflowRerankHandler(c, info, resp)
|
||||
case constant.RelayModeEmbeddings:
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
case constant.RelayModeCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
fallthrough
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fallthrough
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
usage, err = adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user