Compare commits

...

9 Commits

Author SHA1 Message Date
Calcium-Ion
40e640511b Merge pull request #1139 from RedwindA/gemini-fix
feat: 增加对Gemini MimeType类型的验证
2025-06-02 22:33:01 +08:00
Calcium-Ion
5930bb88bf Merge pull request #1140 from RedwindA/gemini-tool-fix
fix: 完善Gemini渠道对tools中additionalProperties的清理
2025-06-02 22:32:43 +08:00
Calcium-Ion
8948e99eeb Merge pull request #1141 from xqx121/patch-1
Fix: The edit interface is not billed (usage-based pricing).
2025-06-02 22:32:18 +08:00
xqx121
37caafc722 Fix: The edit interface is not billed (usage-based pricing). 2025-06-02 22:11:11 +08:00
RedwindA
148c974912 feat: 增加对GeminiMIME类型的验证 2025-06-02 19:00:55 +08:00
RedwindA
f1ee9a301d refactor: enhance cleanFunctionParameters for improved handling of JSON schema, including support for $defs and conditional keywords 2025-06-01 02:08:13 +08:00
CaIon
611d77e1a9 feat: add ToMap method and enhance OpenAI request handling 2025-06-01 01:10:10 +08:00
Calcium-Ion
b05bb899f1 Merge pull request #1134 from QuantumNous/fix_ping_keepalive
fix: 流式请求ping
2025-05-31 22:16:16 +08:00
creamlike1024
c51a30b862 fix: 流式请求ping 2025-05-31 22:13:17 +08:00
5 changed files with 221 additions and 129 deletions

View File

@@ -2,6 +2,7 @@ package dto
import (
"encoding/json"
"one-api/common"
"strings"
)
@@ -57,6 +58,13 @@ type GeneralOpenAIRequest struct {
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
}
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any)
data, _ := common.EncodeJson(r)
_ = common.DecodeJson(data, &result)
return result
}
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
@@ -74,11 +82,11 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int {
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {
func (r *GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
return nil
}

View File

@@ -104,6 +104,65 @@ func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return targetConn, nil
}
func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
pingerCtx, stopPinger := context.WithCancel(context.Background())
gopool.Go(func() {
defer func() {
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
}()
if pingInterval <= 0 {
pingInterval = helper.DefaultPingInterval
}
ticker := time.NewTicker(pingInterval)
// 退出时清理 ticker
defer ticker.Stop()
var pingMutex sync.Mutex
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
for {
select {
// 发送 ping 数据
case <-ticker.C:
if err := sendPingData(c, &pingMutex); err != nil {
return
}
// 收到退出信号
case <-pingerCtx.Done():
return
// request 结束
case <-c.Request.Context().Done():
return
}
}
})
return stopPinger
}
func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
mutex.Lock()
defer mutex.Unlock()
err := helper.PingData(c)
if err != nil {
common2.LogError(c, "SSE ping error: "+err.Error())
return err
}
if common2.DebugEnabled {
println("SSE ping data sent.")
}
return nil
}
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
@@ -115,69 +174,28 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
} else {
client = service.GetHttpClient()
}
// 流式请求 ping 保活
var stopPinger func()
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
var pingerWg sync.WaitGroup
if info.IsStream {
helper.SetEventStreamHeaders(c)
if pingEnabled {
// 处理流式请求的 ping 保活
generalSettings := operation_setting.GetGeneralSetting()
if generalSettings.PingIntervalEnabled {
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
var pingerCtx context.Context
pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
// 退出时清理 pingerCtx 防止泄露
stopPinger := startPingKeepAlive(c, pingInterval)
defer stopPinger()
pingerWg.Add(1)
gopool.Go(func() {
defer pingerWg.Done()
if pingInterval <= 0 {
pingInterval = helper.DefaultPingInterval
}
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
var pingMutex sync.Mutex
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
for {
select {
case <-ticker.C:
pingMutex.Lock()
err2 := helper.PingData(c)
pingMutex.Unlock()
if err2 != nil {
common2.LogError(c, "SSE ping error: "+err.Error())
return
}
if common2.DebugEnabled {
println("SSE ping data sent.")
}
case <-pingerCtx.Done():
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
return
}
}
})
}
}
resp, err := client.Do(req)
// request结束后等待 ping goroutine 完成
if info.IsStream && pingEnabled {
pingerWg.Wait()
}
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"strings"
"github.com/gin-gonic/gin"
)
@@ -49,6 +50,18 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
if strings.HasSuffix(info.UpstreamModelName, "-search") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
request.Model = info.UpstreamModelName
toMap := request.ToMap()
toMap["web_search"] = map[string]any{
"enable": true,
"enable_citation": true,
"enable_trace": true,
"enable_status": false,
}
return toMap, nil
}
return request, nil
}

View File

@@ -18,6 +18,24 @@ import (
"github.com/gin-gonic/gin"
)
var geminiSupportedMimeTypes = map[string]bool{
"application/pdf": true,
"audio/mpeg": true,
"audio/mp3": true,
"audio/wav": true,
"image/png": true,
"image/jpeg": true,
"text/plain": true,
"video/mov": true,
"video/mpeg": true,
"video/mp4": true,
"video/mpg": true,
"video/avi": true,
"video/wmv": true,
"video/mpegps": true,
"video/flv": true,
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
@@ -215,14 +233,20 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
// 判断是否是url
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url获取图片的类型和base64编码的数据
// 是url获取文件的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
}
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: fileData.MimeType,
MimeType: fileData.MimeType, // 使用原始的 MimeType因为大小写可能对API有意义
Data: fileData.Base64Data,
},
})
@@ -291,100 +315,126 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return &geminiRequest, nil
}
// Helper function to get a list of supported MIME types for error messages
func getSupportedMimeTypesList() []string {
keys := make([]string, 0, len(geminiSupportedMimeTypes))
for k := range geminiSupportedMimeTypes {
keys = append(keys, k)
}
return keys
}
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
return nil
}
paramMap, ok := params.(map[string]interface{})
if !ok {
// Not a map, return as is (e.g., could be an array or primitive)
return params
}
switch v := params.(type) {
case map[string]interface{}:
// Create a copy to avoid modifying the original
cleanedMap := make(map[string]interface{})
for k, val := range v {
cleanedMap[k] = val
}
// Create a copy to avoid modifying the original
cleanedMap := make(map[string]interface{})
for k, v := range paramMap {
cleanedMap[k] = v
}
// Remove unsupported root-level fields
delete(cleanedMap, "default")
delete(cleanedMap, "exclusiveMaximum")
delete(cleanedMap, "exclusiveMinimum")
delete(cleanedMap, "$schema")
delete(cleanedMap, "additionalProperties")
// Remove unsupported root-level fields
delete(cleanedMap, "default")
delete(cleanedMap, "exclusiveMaximum")
delete(cleanedMap, "exclusiveMinimum")
delete(cleanedMap, "$schema")
delete(cleanedMap, "additionalProperties")
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
propMap, ok := propValue.(map[string]interface{})
if !ok {
cleanedProps[propName] = propValue // Keep non-map properties
continue
}
// Create a copy of the property map
cleanedPropMap := make(map[string]interface{})
for k, v := range propMap {
cleanedPropMap[k] = v
}
// Remove unsupported fields
delete(cleanedPropMap, "default")
delete(cleanedPropMap, "exclusiveMaximum")
delete(cleanedPropMap, "exclusiveMinimum")
delete(cleanedPropMap, "$schema")
delete(cleanedPropMap, "additionalProperties")
// Check and clean 'format' for string types
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && propType == "string" {
if formatValue, formatExists := cleanedPropMap["format"].(string); formatExists {
if formatValue != "enum" && formatValue != "date-time" {
delete(cleanedPropMap, "format")
}
// Check and clean 'format' for string types
if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
if formatValue != "enum" && formatValue != "date-time" {
delete(cleanedMap, "format")
}
}
}
// Recursively clean nested properties within this property if it's an object/array
// Check the type before recursing
if propType, typeExists := cleanedPropMap["type"].(string); typeExists && (propType == "object" || propType == "array") {
cleanedProps[propName] = cleanFunctionParameters(cleanedPropMap)
} else {
cleanedProps[propName] = cleanedPropMap // Assign the cleaned map back if not recursing
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
cleanedProps[propName] = cleanFunctionParameters(propValue)
}
cleanedMap["properties"] = cleanedProps
}
cleanedMap["properties"] = cleanedProps
}
// Recursively clean items in arrays if needed (e.g., type: array, items: { ... })
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
cleanedMap["items"] = cleanFunctionParameters(items)
}
// Also handle items if it's an array of schemas
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
cleanedItemsArray := make([]interface{}, len(itemsArray))
for i, item := range itemsArray {
cleanedItemsArray[i] = cleanFunctionParameters(item)
// Recursively clean items in arrays
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
cleanedMap["items"] = cleanFunctionParameters(items)
}
cleanedMap["items"] = cleanedItemsArray
}
// Recursively clean other schema composition keywords if necessary
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
cleanedNested := make([]interface{}, len(nested))
for i, item := range nested {
cleanedNested[i] = cleanFunctionParameters(item)
// Also handle items if it's an array of schemas
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
cleanedItemsArray := make([]interface{}, len(itemsArray))
for i, item := range itemsArray {
cleanedItemsArray[i] = cleanFunctionParameters(item)
}
cleanedMap[field] = cleanedNested
cleanedMap["items"] = cleanedItemsArray
}
}
return cleanedMap
// Recursively clean other schema composition keywords
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
cleanedNested := make([]interface{}, len(nested))
for i, item := range nested {
cleanedNested[i] = cleanFunctionParameters(item)
}
cleanedMap[field] = cleanedNested
}
}
// Recursively clean patternProperties
if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
cleanedPatternProps := make(map[string]interface{})
for pattern, schema := range patternProps {
cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
}
cleanedMap["patternProperties"] = cleanedPatternProps
}
// Recursively clean definitions
if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
cleanedDefinitions := make(map[string]interface{})
for defName, defSchema := range definitions {
cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
}
cleanedMap["definitions"] = cleanedDefinitions
}
// Recursively clean $defs (newer JSON Schema draft)
if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
cleanedDefs := make(map[string]interface{})
for defName, defSchema := range defs {
cleanedDefs[defName] = cleanFunctionParameters(defSchema)
}
cleanedMap["$defs"] = cleanedDefs
}
// Clean conditional keywords
for _, field := range []string{"if", "then", "else", "not"} {
if nested, ok := cleanedMap[field]; ok {
cleanedMap[field] = cleanFunctionParameters(nested)
}
}
return cleanedMap
case []interface{}:
// Handle arrays of schemas
cleanedArray := make([]interface{}, len(v))
for i, item := range v {
cleanedArray[i] = cleanFunctionParameters(item)
}
return cleanedArray
default:
// Not a map or array, return as is (e.g., could be a primitive)
return params
}
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {

View File

@@ -41,6 +41,9 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
imageRequest.Quality = "standard"
}
}
if imageRequest.N == 0 {
imageRequest.N = 1
}
default:
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {