Compare commits

...

4 Commits

13 changed files with 187 additions and 50 deletions

3
.gitignore vendored
View File

@@ -9,4 +9,5 @@ logs
web/dist
.env
one-api
.DS_Store
.DS_Store
tiktoken_cache

View File

@@ -73,25 +73,25 @@ func LoadEnv() {
DebugEnabled = os.Getenv("DEBUG") == "true"
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
// Parse requestInterval and set RequestInterval
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
RequestInterval = time.Duration(requestInterval) * time.Second
// Initialize variables with GetEnvOrDefault
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
// Initialize string variables with GetEnvOrDefaultString
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
// Initialize rate limit variables
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))

View File

@@ -15,6 +15,7 @@ services:
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

View File

@@ -7,7 +7,7 @@ type ClaudeMetadata struct {
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Type string `json:"type,omitempty"`
Text *string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
@@ -50,6 +50,11 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
return ""
}
func (c *ClaudeMediaMessage) GetJsonRowString() string {
jsonContent, _ := json.Marshal(c)
return string(jsonContent)
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent

View File

@@ -111,6 +111,7 @@ type MediaContent struct {
Text string `json:"text,omitempty"`
ImageUrl any `json:"image_url,omitempty"`
InputAudio any `json:"input_audio,omitempty"`
File any `json:"file,omitempty"`
}
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
@@ -120,6 +121,20 @@ func (m *MediaContent) GetImageMedia() *MessageImageUrl {
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
return m.InputAudio.(*MessageInputAudio)
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
return m.File.(*MessageFile)
}
return nil
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
@@ -135,10 +150,17 @@ type MessageInputAudio struct {
Format string `json:"format"`
}
type MessageFile struct {
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"`
FileId string `json:"file_id,omitempty"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file"
)
func (m *Message) GetPrefix() bool {
@@ -192,6 +214,12 @@ func (m *Message) StringContent() string {
return stringContent
}
func (m *Message) SetNullContent() {
m.Content = nil
m.parsedStringContent = nil
m.parsedContent = nil
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
@@ -292,6 +320,30 @@ func (m *Message) ParseContent() []MediaContent {
})
}
}
case ContentTypeFile:
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
fileId, ok3 := fileData["file_id"].(string)
if ok3 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileId: fileId,
},
})
} else {
fileName, ok1 := fileData["filename"].(string)
fileDataStr, ok2 := fileData["file_data"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileName: fileName,
FileData: fileDataStr,
},
})
}
}
}
}
}
}

View File

@@ -34,7 +34,7 @@ var indexPage []byte
func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysLog("Support for .env file is disabled")
common.SysLog("Support for .env file is disabled: " + err.Error())
}
common.LoadEnv()

View File

@@ -208,6 +208,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
})
}
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
}
}
@@ -233,7 +261,6 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return &geminiRequest, nil
}
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
@@ -307,7 +334,6 @@ func cleanFunctionParameters(params interface{}) interface{} {
cleanedMap["items"] = cleanedItemsArray
}
// Recursively clean other schema composition keywords if necessary
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
@@ -322,7 +348,6 @@ func cleanFunctionParameters(params interface{}) interface{} {
return cleanedMap
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema

View File

@@ -31,6 +31,9 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return err
}
if streamResponse.Usage != nil {
info.ClaudeConvertInfo.Usage = streamResponse.Usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
@@ -170,15 +173,14 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
helper.Done(c)
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
if !containStreamUsage {
streamResponse.Usage = usage
}
info.ClaudeConvertInfo.Usage = usage
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {

View File

@@ -170,8 +170,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
}
if shouldSendLastResp {
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 处理token计算

View File

@@ -19,13 +19,18 @@ type ThinkingContentInfo struct {
}
const (
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeNone = "none"
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
LastMessageTypeThinking = "thinking"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
Usage *dto.Usage
FinishReason string
Done bool
}
const (
@@ -113,7 +118,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{
LastMessagesType: LastMessageTypeText,
LastMessagesType: LastMessageTypeNone,
}
return info
}

View File

@@ -45,7 +45,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
// Add system message if present
if claudeRequest.System != nil {
if claudeRequest.IsStringSystem() {
if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
openAIMessage := dto.Message{
Role: "system",
}
@@ -122,23 +122,22 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
}
encodeJson, _ := common.EncodeJson(mediaContents)
oaiToolMessage.SetStringContent(string(encodeJson))
}
openAIMessages = append(openAIMessages, oaiToolMessage)
}
}
if len(mediaMessages) > 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
if len(toolCalls) > 0 {
openAIMessage.SetToolCalls(toolCalls)
}
if len(mediaMessages) > 0 && len(toolCalls) == 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
}
if len(openAIMessage.ParseContent()) > 0 {
if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
openAIMessages = append(openAIMessages, openAIMessage)
}
}
@@ -211,15 +210,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
resp := &dto.ClaudeResponse{
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
}
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
//resp := &dto.ClaudeResponse{
// Type: "content_block_start",
// ContentBlock: &dto.ClaudeMediaMessage{
// Type: "text",
// Text: common.GetPointer[string](""),
// },
//}
//resp.SetIndex(0)
//claudeResponses = append(claudeResponses, resp)
}
return claudeResponses
}
@@ -232,16 +231,20 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
info.FinishReason = *chosenChoice.FinishReason
return claudeResponses
}
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
if openAIResponse.Usage != nil {
if info.ClaudeConvertInfo.Usage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: openAIResponse.Usage.PromptTokens,
OutputTokens: openAIResponse.Usage.CompletionTokens,
InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)),
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
},
})
}
@@ -250,10 +253,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
})
} else {
var claudeResponse dto.ClaudeResponse
claudeResponse.SetIndex(0)
var isEmpty bool
claudeResponse.Type = "content_block_delta"
if len(chosenChoice.Delta.ToolCalls) > 0 {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
@@ -274,15 +277,55 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
}
} else {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
reasoning := chosenChoice.Delta.GetReasoningContent()
textContent := chosenChoice.Delta.GetContentString()
if reasoning != "" || textContent != "" {
if reasoning != "" {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
//info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "thinking",
Thinking: "",
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "thinking_delta",
Thinking: reasoning,
}
} else {
if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](textContent),
}
}
} else {
isEmpty = true
}
}
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &claudeResponse)
if !isEmpty {
claudeResponses = append(claudeResponses, &claudeResponse)
}
}
}

View File

@@ -8,9 +8,9 @@ import (
"one-api/dto"
)
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
resp, err := DoDownloadRequest(url)
if err != nil {
return nil, err
@@ -22,7 +22,6 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
if err != nil {
return nil, err
}
// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)

View File

@@ -398,6 +398,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else if m.Type == dto.ContentTypeFile {
tokenNum += 5000
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}