mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-12 06:47:27 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e7587ab46 | ||
|
|
cc5066c510 | ||
|
|
b9b69b01e5 | ||
|
|
1f4f9123aa | ||
|
|
9cc6385b0c | ||
|
|
2d42145b66 | ||
|
|
94736407a0 | ||
|
|
de859c3cc9 | ||
|
|
8dd4ce986c | ||
|
|
06da65a9d0 |
@@ -24,7 +24,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations:
|
||||
err = relay.ImageHelper(c, relayMode)
|
||||
err = relay.ImageHelper(c)
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
|
||||
@@ -88,15 +88,15 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
// parsedContent not json field
|
||||
parsedContent []MediaContent
|
||||
Name *string `json:"name,omitempty"`
|
||||
Prefix *bool `json:"prefix,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Prefix *bool `json:"prefix,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
parsedContent []MediaContent
|
||||
parsedStringContent *string
|
||||
}
|
||||
|
||||
type MediaContent struct {
|
||||
@@ -150,6 +150,9 @@ func (m *Message) SetToolCalls(toolCalls any) {
|
||||
}
|
||||
|
||||
func (m *Message) StringContent() string {
|
||||
if m.parsedStringContent != nil {
|
||||
return *m.parsedStringContent
|
||||
}
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return stringContent
|
||||
@@ -160,16 +163,24 @@ func (m *Message) StringContent() string {
|
||||
func (m *Message) SetStringContent(content string) {
|
||||
jsonContent, _ := json.Marshal(content)
|
||||
m.Content = jsonContent
|
||||
m.parsedStringContent = &content
|
||||
m.parsedContent = nil
|
||||
}
|
||||
|
||||
func (m *Message) SetMediaContent(content []MediaContent) {
|
||||
jsonContent, _ := json.Marshal(content)
|
||||
m.Content = jsonContent
|
||||
m.parsedContent = nil
|
||||
m.parsedStringContent = nil
|
||||
}
|
||||
|
||||
func (m *Message) IsStringContent() bool {
|
||||
if m.parsedStringContent != nil {
|
||||
return true
|
||||
}
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
m.parsedStringContent = &stringContent
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -179,72 +190,86 @@ func (m *Message) ParseContent() []MediaContent {
|
||||
if m.parsedContent != nil {
|
||||
return m.parsedContent
|
||||
}
|
||||
|
||||
var contentList []MediaContent
|
||||
defer func() {
|
||||
if len(contentList) > 0 {
|
||||
m.parsedContent = contentList
|
||||
}
|
||||
}()
|
||||
|
||||
// 先尝试解析为字符串
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaContent{
|
||||
contentList = []MediaContent{{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
}}
|
||||
m.parsedContent = contentList
|
||||
return contentList
|
||||
}
|
||||
var arrayContent []json.RawMessage
|
||||
|
||||
// 尝试解析为数组
|
||||
var arrayContent []map[string]interface{}
|
||||
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
||||
for _, contentItem := range arrayContent {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
|
||||
contentType, ok := contentItem["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch contentMap["type"] {
|
||||
|
||||
switch contentType {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
if text, ok := contentItem["text"].(string); ok {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
|
||||
case ContentTypeImageURL:
|
||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||
detail, ok := subObj["detail"]
|
||||
if ok {
|
||||
subObj["detail"] = detail.(string)
|
||||
} else {
|
||||
subObj["detail"] = "high"
|
||||
}
|
||||
imageUrl := contentItem["image_url"]
|
||||
switch v := imageUrl.(type) {
|
||||
case string:
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
Detail: subObj["detail"].(string),
|
||||
},
|
||||
})
|
||||
} else if url, ok := contentMap["image_url"].(string); ok {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: url,
|
||||
Url: v,
|
||||
Detail: "high",
|
||||
},
|
||||
})
|
||||
case map[string]interface{}:
|
||||
url, ok1 := v["url"].(string)
|
||||
detail, ok2 := v["detail"].(string)
|
||||
if !ok2 {
|
||||
detail = "high"
|
||||
}
|
||||
if ok1 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: url,
|
||||
Detail: detail,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
case ContentTypeInputAudio:
|
||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: subObj["data"].(string),
|
||||
Format: subObj["format"].(string),
|
||||
},
|
||||
})
|
||||
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
|
||||
data, ok1 := audioData["data"].(string)
|
||||
format, ok2 := audioData["format"].(string)
|
||||
if ok1 && ok2 {
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: data,
|
||||
Format: format,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
return nil
|
||||
|
||||
if len(contentList) > 0 {
|
||||
m.parsedContent = contentList
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
@@ -62,9 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoiceDelta struct {
|
||||
Content *string `json:"content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Content *string `json:"content,omitempty"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
|
||||
@@ -78,6 +79,13 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
|
||||
return *c.Content
|
||||
}
|
||||
|
||||
func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
|
||||
if c.ReasoningContent == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.ReasoningContent
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
// Index is not nil only in chat completion chunk object
|
||||
Index *int `json:"index,omitempty"`
|
||||
|
||||
@@ -84,7 +84,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
|
||||
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
|
||||
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
@@ -306,7 +306,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.QuotaForInvitee, _ = strconv.Atoi(value)
|
||||
case "QuotaRemindThreshold":
|
||||
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
|
||||
case "PreConsumedQuota":
|
||||
case "ShouldPreConsumedQuota":
|
||||
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
||||
case "RetryTimes":
|
||||
common.RetryTimes, _ = strconv.Atoi(value)
|
||||
|
||||
@@ -162,6 +162,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
//}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
@@ -182,6 +183,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
//}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
|
||||
if choice.Delta.ToolCalls != nil {
|
||||
if len(choice.Delta.ToolCalls) > toolCount {
|
||||
toolCount = len(choice.Delta.ToolCalls)
|
||||
@@ -273,7 +275,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent, model)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
|
||||
@@ -13,24 +13,24 @@ import (
|
||||
)
|
||||
|
||||
type RelayInfo struct {
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
setFirstResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
RecodeModelName string
|
||||
ChannelType int
|
||||
ChannelId int
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
setFirstResponse bool
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
//RecodeModelName string
|
||||
RequestURLPath string
|
||||
ApiVersion string
|
||||
PromptTokens int
|
||||
@@ -39,6 +39,7 @@ type RelayInfo struct {
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
IsModelMapped bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
@@ -50,6 +51,18 @@ type RelayInfo struct {
|
||||
ChannelSetting map[string]interface{}
|
||||
}
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
var streamSupportedChannels = map[int]bool{
|
||||
common.ChannelTypeOpenAI: true,
|
||||
common.ChannelTypeAnthropic: true,
|
||||
common.ChannelTypeAws: true,
|
||||
common.ChannelTypeGemini: true,
|
||||
common.ChannelCloudflare: true,
|
||||
common.ChannelTypeAzure: true,
|
||||
common.ChannelTypeVolcEngine: true,
|
||||
common.ChannelTypeOllama: true,
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.ClientWs = ws
|
||||
@@ -89,12 +102,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
OriginModelName: c.GetString("original_model"),
|
||||
UpstreamModelName: c.GetString("original_model"),
|
||||
RecodeModelName: c.GetString("recode_model"),
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelSetting: channelSetting,
|
||||
//RecodeModelName: c.GetString("original_model"),
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
ApiVersion: c.GetString("api_version"),
|
||||
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
|
||||
Organization: c.GetString("channel_organization"),
|
||||
ChannelSetting: channelSetting,
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
|
||||
info.IsPlayground = true
|
||||
@@ -110,10 +124,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if info.ChannelType == common.ChannelTypeVertexAi {
|
||||
info.ApiVersion = c.GetString("region")
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
|
||||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure ||
|
||||
info.ChannelType == common.ChannelTypeVolcEngine || info.ChannelType == common.ChannelTypeOllama {
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
return info
|
||||
|
||||
25
relay/helper/model_mapped.go
Normal file
25
relay/helper/model_mapped.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/relay/common"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal_model_mapping_failed")
|
||||
}
|
||||
if modelMap[info.OriginModelName] != "" {
|
||||
info.UpstreamModelName = modelMap[info.OriginModelName]
|
||||
info.IsModelMapped = true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
41
relay/helper/price.go
Normal file
41
relay/helper/price.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
GroupRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) PriceData {
|
||||
modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if maxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + maxTokens
|
||||
}
|
||||
modelRatio = common.GetModelRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
return PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: groupRatio,
|
||||
UsePrice: usePrice,
|
||||
ShouldPreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -11,8 +10,10 @@ import (
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
@@ -27,8 +28,9 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err := service.CheckSensitiveInput(audioRequest.Input)
|
||||
words, err := service.CheckSensitiveInput(audioRequest.Input)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -73,15 +75,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(audioRequest.Model)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
priceData := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -91,19 +91,12 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}()
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[audioRequest.Model] != "" {
|
||||
audioRequest.Model = modelMap[audioRequest.Model]
|
||||
}
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
relayInfo.UpstreamModelName = audioRequest.Model
|
||||
|
||||
audioRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
@@ -140,7 +133,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
@@ -60,15 +61,16 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||
//}
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err := service.CheckSensitiveInput(imageRequest.Prompt)
|
||||
words, err := service.CheckSensitiveInput(imageRequest.Prompt)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return imageRequest, nil
|
||||
}
|
||||
|
||||
func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
@@ -77,29 +79,20 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[imageRequest.Model] != "" {
|
||||
imageRequest.Model = modelMap[imageRequest.Model]
|
||||
}
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
relayInfo.UpstreamModelName = imageRequest.Model
|
||||
|
||||
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
|
||||
if !success {
|
||||
modelRatio := common.GetModelRatio(imageRequest.Model)
|
||||
imageRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
priceData := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if !priceData.UsePrice {
|
||||
// modelRatio 16 = modelPrice $0.04
|
||||
// per 1 modelRatio = $0.04 / 16
|
||||
modelPrice = 0.0025 * modelRatio
|
||||
priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
||||
}
|
||||
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
|
||||
sizeRatio := 1.0
|
||||
@@ -122,11 +115,11 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
}
|
||||
|
||||
imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
|
||||
imageRatio := priceData.ModelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota := int(imageRatio * priceData.GroupRatio * common.QuotaPerUnit)
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("image pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, quota)), "insufficient_user_quota", http.StatusBadRequest)
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
@@ -184,7 +177,6 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
|
||||
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
||||
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage, 0, userQuota, priceData, logContent)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
@@ -76,40 +77,21 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
//isModelMapped := false
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
//isModelMapped = true
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
relayInfo.RecodeModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err = checkRequestSensitive(textRequest, relayInfo)
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
var promptTokens int
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -124,20 +106,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
}
|
||||
modelRatio = common.GetModelRatio(textRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
priceData := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -220,10 +192,10 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
if strings.HasPrefix(relayInfo.RecodeModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
} else {
|
||||
postConsumeQuota(c, relayInfo, relayInfo.RecodeModelName, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -248,19 +220,20 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
return promptTokens, err
|
||||
}
|
||||
|
||||
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
|
||||
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
|
||||
var err error
|
||||
var words []string
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||
words, err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||
words, err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||
case relayconstant.RelayModeModerations:
|
||||
err = service.CheckSensitiveInput(textRequest.Input)
|
||||
words, err = service.CheckSensitiveInput(textRequest.Input)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
err = service.CheckSensitiveInput(textRequest.Input)
|
||||
words, err = service.CheckSensitiveInput(textRequest.Input)
|
||||
}
|
||||
return err
|
||||
return words, err
|
||||
}
|
||||
|
||||
// 预扣费并返回用户剩余配额
|
||||
@@ -273,7 +246,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
@@ -319,9 +292,8 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
|
||||
}
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.PromptTokens,
|
||||
@@ -333,12 +305,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
ratio := priceData.ModelRatio * priceData.GroupRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
quota := 0
|
||||
if !usePrice {
|
||||
if !priceData.UsePrice {
|
||||
quota = promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
|
||||
quota = int(math.Round(float64(quota) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||
@@ -47,43 +47,20 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[embeddingRequest.Model] != "" {
|
||||
embeddingRequest.Model = modelMap[embeddingRequest.Model]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
}
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
relayInfo.UpstreamModelName = embeddingRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(embeddingRequest.Model, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
embeddingRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
if !success {
|
||||
preConsumedTokens := promptToken
|
||||
modelRatio = common.GetModelRatio(embeddingRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -132,6 +109,6 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, embeddingRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
@@ -40,43 +40,20 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[rerankRequest.Model] != "" {
|
||||
rerankRequest.Model = modelMap[rerankRequest.Model]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
}
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
relayInfo.UpstreamModelName = rerankRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
rerankRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
if !success {
|
||||
preConsumedTokens := promptToken
|
||||
modelRatio = common.GetModelRatio(rerankRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
priceData := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -124,6 +101,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,10 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
if relayInfo.ReasoningEffort != "" {
|
||||
other["reasoning_effort"] = relayInfo.ReasoningEffort
|
||||
}
|
||||
if relayInfo.IsModelMapped {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = relayInfo.UpstreamModelName
|
||||
}
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
|
||||
other["admin_info"] = adminInfo
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -68,7 +69,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
return err
|
||||
}
|
||||
|
||||
modelName := relayInfo.UpstreamModelName
|
||||
modelName := relayInfo.OriginModelName
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
@@ -94,11 +95,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
|
||||
if userQuota < quota {
|
||||
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
|
||||
}
|
||||
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
|
||||
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
|
||||
}
|
||||
|
||||
err = PostConsumeQuota(relayInfo, quota, 0, false)
|
||||
@@ -122,7 +123,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
@@ -173,8 +174,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.PromptTokensDetails.TextTokens
|
||||
@@ -184,9 +184,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(relayInfo.RecodeModelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.RecodeModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.RecodeModelName)
|
||||
completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.OriginModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName)
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -197,7 +202,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
TextTokens: textOutTokens,
|
||||
AudioTokens: audioOutTokens,
|
||||
},
|
||||
ModelName: relayInfo.RecodeModelName,
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
UsePrice: usePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: groupRatio,
|
||||
@@ -220,7 +225,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.RecodeModelName, preConsumedQuota))
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
|
||||
} else {
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
@@ -233,7 +238,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
logModel := relayInfo.RecodeModelName
|
||||
logModel := relayInfo.OriginModelName
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
@@ -257,7 +262,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
return err
|
||||
}
|
||||
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
||||
return errors.New("令牌额度不足")
|
||||
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
|
||||
}
|
||||
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,48 +8,47 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CheckSensitiveMessages(messages []dto.Message) error {
|
||||
func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
if len(message.Content) > 0 {
|
||||
if message.IsStringContent() {
|
||||
stringContent := message.StringContent()
|
||||
if ok, words := SensitiveWordContains(stringContent); ok {
|
||||
return errors.New("sensitive words: " + strings.Join(words, ","))
|
||||
}
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == "image_url" {
|
||||
// TODO: check image url
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == "image_url" {
|
||||
// TODO: check image url
|
||||
} else {
|
||||
if ok, words := SensitiveWordContains(m.Text); ok {
|
||||
return errors.New("sensitive words: " + strings.Join(words, ","))
|
||||
}
|
||||
}
|
||||
// 检查 text 是否为空
|
||||
if m.Text == "" {
|
||||
continue
|
||||
}
|
||||
if ok, words := SensitiveWordContains(m.Text); ok {
|
||||
return words, errors.New("sensitive words detected")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func CheckSensitiveText(text string) error {
|
||||
func CheckSensitiveText(text string) ([]string, error) {
|
||||
if ok, words := SensitiveWordContains(text); ok {
|
||||
return errors.New("sensitive words: " + strings.Join(words, ","))
|
||||
return words, errors.New("sensitive words detected")
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func CheckSensitiveInput(input any) error {
|
||||
func CheckSensitiveInput(input any) ([]string, error) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CheckSensitiveText(v)
|
||||
case []string:
|
||||
text := ""
|
||||
var builder strings.Builder
|
||||
for _, s := range v {
|
||||
text += s
|
||||
builder.WriteString(s)
|
||||
}
|
||||
return CheckSensitiveText(text)
|
||||
return CheckSensitiveText(builder.String())
|
||||
}
|
||||
return CheckSensitiveText(fmt.Sprintf("%v", input))
|
||||
}
|
||||
@@ -59,8 +58,11 @@ func SensitiveWordContains(text string) (bool, []string) {
|
||||
if len(setting.SensitiveWords) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
if len(text) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
checkText := strings.ToLower(text)
|
||||
return AcSearch(checkText, setting.SensitiveWords, false)
|
||||
return AcSearch(checkText, setting.SensitiveWords, true)
|
||||
}
|
||||
|
||||
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
|
||||
@@ -72,14 +74,21 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
|
||||
m := InitAc(setting.SensitiveWords)
|
||||
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
|
||||
if len(hits) > 0 {
|
||||
words := make([]string, 0)
|
||||
words := make([]string, 0, len(hits))
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(text))
|
||||
lastPos := 0
|
||||
|
||||
for _, hit := range hits {
|
||||
pos := hit.Pos
|
||||
word := string(hit.Word)
|
||||
text = text[:pos] + "**###**" + text[pos+len(word):]
|
||||
builder.WriteString(text[lastPos:pos])
|
||||
builder.WriteString("**###**")
|
||||
lastPos = pos + len(word)
|
||||
words = append(words, word)
|
||||
}
|
||||
return true, words, text
|
||||
builder.WriteString(text[lastPos:])
|
||||
return true, words, builder.String()
|
||||
}
|
||||
return false, nil, text
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
Button, Descriptions,
|
||||
Form,
|
||||
Layout,
|
||||
Modal,
|
||||
Modal, Popover,
|
||||
Select,
|
||||
Space,
|
||||
Spin,
|
||||
@@ -34,6 +34,7 @@ import {
|
||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||
import { getLogOther } from '../helpers/other.js';
|
||||
import { StyleContext } from '../context/Style/index.js';
|
||||
import { IconInherit, IconRefresh } from '@douyinfe/semi-icons';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
@@ -141,7 +142,78 @@ const LogsTable = () => {
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function renderModelName(record) {
|
||||
|
||||
let other = getLogOther(record.other);
|
||||
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
|
||||
if (!modelMapped) {
|
||||
return <Tag
|
||||
color={stringToColor(record.model_name)}
|
||||
size='large'
|
||||
onClick={(event) => {
|
||||
copyText(event, record.model_name).then(r => {});
|
||||
}}
|
||||
>
|
||||
{' '}{record.model_name}{' '}
|
||||
</Tag>;
|
||||
} else {
|
||||
return (
|
||||
<>
|
||||
<Space vertical align={'start'}>
|
||||
<Popover content={
|
||||
<div style={{padding: 10}}>
|
||||
<Space vertical align={'start'}>
|
||||
<Tag
|
||||
color={stringToColor(record.model_name)}
|
||||
size='large'
|
||||
onClick={(event) => {
|
||||
copyText(event, record.model_name).then(r => {});
|
||||
}}
|
||||
>
|
||||
{t('请求并计费模型')}{' '}{record.model_name}{' '}
|
||||
</Tag>
|
||||
<Tag
|
||||
color={stringToColor(other.upstream_model_name)}
|
||||
size='large'
|
||||
onClick={(event) => {
|
||||
copyText(event, other.upstream_model_name).then(r => {});
|
||||
}}
|
||||
>
|
||||
{t('实际模型')}{' '}{other.upstream_model_name}{' '}
|
||||
</Tag>
|
||||
</Space>
|
||||
</div>
|
||||
}>
|
||||
<Tag
|
||||
color={stringToColor(record.model_name)}
|
||||
size='large'
|
||||
onClick={(event) => {
|
||||
copyText(event, record.model_name).then(r => {});
|
||||
}}
|
||||
suffixIcon={<IconRefresh />}
|
||||
>
|
||||
{' '}{record.model_name}{' '}
|
||||
</Tag>
|
||||
</Popover>
|
||||
{/*<Tooltip content={t('实际模型')}>*/}
|
||||
{/* <Tag*/}
|
||||
{/* color={stringToColor(other.upstream_model_name)}*/}
|
||||
{/* size='large'*/}
|
||||
{/* onClick={(event) => {*/}
|
||||
{/* copyText(event, other.upstream_model_name).then(r => {});*/}
|
||||
{/* }}*/}
|
||||
{/* >*/}
|
||||
{/* {' '}{other.upstream_model_name}{' '}*/}
|
||||
{/* </Tag>*/}
|
||||
{/*</Tooltip>*/}
|
||||
</Space>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const columns = [
|
||||
{
|
||||
@@ -272,18 +344,7 @@ const LogsTable = () => {
|
||||
dataIndex: 'model_name',
|
||||
render: (text, record, index) => {
|
||||
return record.type === 0 || record.type === 2 ? (
|
||||
<>
|
||||
<Tag
|
||||
color={stringToColor(text)}
|
||||
size='large'
|
||||
onClick={(event) => {
|
||||
copyText(event, text);
|
||||
}}
|
||||
>
|
||||
{' '}
|
||||
{text}{' '}
|
||||
</Tag>
|
||||
</>
|
||||
<>{renderModelName(record)}</>
|
||||
) : (
|
||||
<></>
|
||||
);
|
||||
@@ -580,6 +641,17 @@ const LogsTable = () => {
|
||||
value: logs[i].content,
|
||||
});
|
||||
if (logs[i].type === 2) {
|
||||
let modelMapped = other?.is_model_mapped && other?.upstream_model_name && other?.upstream_model_name !== '';
|
||||
if (modelMapped) {
|
||||
expandDataLocal.push({
|
||||
key: t('请求并计费模型'),
|
||||
value: logs[i].model_name,
|
||||
});
|
||||
expandDataLocal.push({
|
||||
key: t('实际模型'),
|
||||
value: other.upstream_model_name,
|
||||
});
|
||||
}
|
||||
let content = '';
|
||||
if (other?.ws || other?.audio) {
|
||||
content = renderAudioModelPrice(
|
||||
|
||||
@@ -1249,5 +1249,26 @@
|
||||
"已注销": "Logged out",
|
||||
"自动禁用关键词": "Automatic disable keywords",
|
||||
"一行一个,不区分大小写": "One line per keyword, not case-sensitive",
|
||||
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
|
||||
"当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel",
|
||||
"请求并计费模型": "Request and charge model",
|
||||
"实际模型": "Actual model",
|
||||
"渠道信息": "Channel information",
|
||||
"通知设置": "Notification settings",
|
||||
"Webhook地址": "Webhook URL",
|
||||
"请输入Webhook地址,例如: https://example.com/webhook": "Please enter the Webhook URL, e.g.: https://example.com/webhook",
|
||||
"邮件通知": "Email notification",
|
||||
"Webhook通知": "Webhook notification",
|
||||
"接口凭证(可选)": "Interface credentials (optional)",
|
||||
"密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "The secret will be added to the request header as a Bearer token to verify the legitimacy of the webhook request",
|
||||
"Authorization: Bearer your-secret-key": "Authorization: Bearer your-secret-key",
|
||||
"额度预警阈值": "Quota warning threshold",
|
||||
"当剩余额度低于此数值时,系统将通过选择的方式发送通知": "When the remaining quota is lower than this value, the system will send a notification through the selected method",
|
||||
"Webhook请求结构": "Webhook request structure",
|
||||
"只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "Only https is supported, the system will send a notification through POST, please ensure the address can receive POST requests",
|
||||
"保存设置": "Save settings",
|
||||
"通知邮箱": "Notification email",
|
||||
"设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "Set the email address for receiving quota warning notifications, if not set, the email address bound to the account will be used",
|
||||
"留空则使用账号绑定的邮箱": "If left blank, the email address bound to the account will be used",
|
||||
"代理站地址": "Base URL",
|
||||
"对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "For official channels, the new-api has a built-in address. Unless it is a third-party proxy site or a special Azure access address, there is no need to fill it in"
|
||||
}
|
||||
@@ -540,21 +540,23 @@ const EditChannel = (props) => {
|
||||
value={inputs.name}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
|
||||
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>{t('BaseURL')}:</Typography.Text>
|
||||
<Typography.Text strong>{t('代理站地址')}:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
label={t('BaseURL')}
|
||||
name="base_url"
|
||||
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
|
||||
onChange={(value) => {
|
||||
handleInputChange('base_url', value);
|
||||
}}
|
||||
value={inputs.base_url}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
<Tooltip content={t('对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写')}>
|
||||
<Input
|
||||
label={t('代理站地址')}
|
||||
name="base_url"
|
||||
placeholder={t('此项可选,用于通过代理站来进行 API 调用,末尾不要带/v1和/')}
|
||||
onChange={(value) => {
|
||||
handleInputChange('base_url', value);
|
||||
}}
|
||||
value={inputs.base_url}
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
</Tooltip>
|
||||
</>
|
||||
)}
|
||||
<div style={{ marginTop: 10 }}>
|
||||
|
||||
Reference in New Issue
Block a user