diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go index ffc263507..25f9c68eb 100644 --- a/common/endpoint_defaults.go +++ b/common/endpoint_defaults.go @@ -23,6 +23,7 @@ var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{ constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"}, constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"}, constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"}, + constant.EndpointTypeEmbeddings: {Path: "/v1/embeddings", Method: "POST"}, } // GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在 diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go index ef096b759..f799e5ba8 100644 --- a/constant/endpoint_type.go +++ b/constant/endpoint_type.go @@ -9,6 +9,7 @@ const ( EndpointTypeGemini EndpointType = "gemini" EndpointTypeJinaRerank EndpointType = "jina-rerank" EndpointTypeImageGeneration EndpointType = "image-generation" + EndpointTypeEmbeddings EndpointType = "embeddings" //EndpointTypeMidjourney EndpointType = "midjourney-proxy" //EndpointTypeSuno EndpointType = "suno-proxy" //EndpointTypeKling EndpointType = "kling" diff --git a/controller/channel-test.go b/controller/channel-test.go index 9ea6eed75..b3a3be4eb 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -38,7 +38,7 @@ type testResult struct { newAPIError *types.NewAPIError } -func testChannel(channel *model.Channel, testModel string) testResult { +func testChannel(channel *model.Channel, testModel string, endpointType string) testResult { tik := time.Now() if channel.Type == constant.ChannelTypeMidjourney { return testResult{ @@ -81,18 +81,26 @@ func testChannel(channel *model.Channel, testModel string) testResult { requestPath := "/v1/chat/completions" - // 先判断是否为 Embedding 模型 - if strings.Contains(strings.ToLower(testModel), "embedding") || - strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 - strings.Contains(testModel, "bge-") || // bge 系列模型 - strings.Contains(testModel, "embed") || - channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型 - requestPath = "/v1/embeddings" // 修改请求路径 - } + // 如果指定了端点类型,使用指定的端点类型 + if endpointType != "" { + if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok { + requestPath = endpointInfo.Path + } + } else { + // 如果没有指定端点类型,使用原有的自动检测逻辑 + // 先判断是否为 Embedding 模型 + if strings.Contains(strings.ToLower(testModel), "embedding") || + strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 + strings.Contains(testModel, "bge-") || // bge 系列模型 + strings.Contains(testModel, "embed") || + channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型 + requestPath = "/v1/embeddings" // 修改请求路径 + } - // VolcEngine 图像生成模型 - if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { - requestPath = "/v1/images/generations" + // VolcEngine 图像生成模型 + if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { + requestPath = "/v1/images/generations" + } } c.Request = &http.Request{ @@ -114,21 +122,6 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - // 重新检查模型类型并更新请求路径 - if strings.Contains(strings.ToLower(testModel), "embedding") || - strings.HasPrefix(testModel, "m3e") || - strings.Contains(testModel, "bge-") || - strings.Contains(testModel, "embed") || - channel.Type == constant.ChannelTypeMokaAI { - requestPath = "/v1/embeddings" - c.Request.URL.Path = requestPath - } - - if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { - requestPath = "/v1/images/generations" - c.Request.URL.Path = requestPath - } - cache, err := model.GetUserCache(1) if err != nil { return testResult{ @@ -153,17 +146,54 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: newAPIError, } } - request := buildTestRequest(testModel) - // Determine relay format based on request path - relayFormat := types.RelayFormatOpenAI - if c.Request.URL.Path == "/v1/embeddings" { - relayFormat = types.RelayFormatEmbedding - } - if c.Request.URL.Path == "/v1/images/generations" { - relayFormat = types.RelayFormatOpenAIImage + // Determine relay format based on endpoint type or request path + var relayFormat types.RelayFormat + if endpointType != "" { + // 根据指定的端点类型设置 relayFormat + switch constant.EndpointType(endpointType) { + case constant.EndpointTypeOpenAI: + relayFormat = types.RelayFormatOpenAI + case constant.EndpointTypeOpenAIResponse: + relayFormat = types.RelayFormatOpenAIResponses + case constant.EndpointTypeAnthropic: + relayFormat = types.RelayFormatClaude + case constant.EndpointTypeGemini: + relayFormat = types.RelayFormatGemini + case constant.EndpointTypeJinaRerank: + relayFormat = types.RelayFormatRerank + case constant.EndpointTypeImageGeneration: + relayFormat = types.RelayFormatOpenAIImage + case constant.EndpointTypeEmbeddings: + relayFormat = types.RelayFormatEmbedding + default: + relayFormat = types.RelayFormatOpenAI + } + } else { + // 根据请求路径自动检测 + relayFormat = types.RelayFormatOpenAI + if c.Request.URL.Path == "/v1/embeddings" { + relayFormat = types.RelayFormatEmbedding + } + if c.Request.URL.Path == "/v1/images/generations" { + relayFormat = types.RelayFormatOpenAIImage + } + if c.Request.URL.Path == "/v1/messages" { + relayFormat = types.RelayFormatClaude + } + if strings.Contains(c.Request.URL.Path, "/v1beta/models") { + relayFormat = types.RelayFormatGemini + } + if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" { + relayFormat = types.RelayFormatRerank + } + if c.Request.URL.Path == "/v1/responses" { + relayFormat = types.RelayFormatOpenAIResponses + } } + request := buildTestRequest(testModel, endpointType) + info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) if err != nil { @@ -186,7 +216,8 @@ func testChannel(channel *model.Channel, testModel string) testResult { } testModel = info.UpstreamModelName - request.Model = testModel + // 更新请求中的模型名称 + request.SetModelName(testModel) apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) @@ -216,33 +247,62 @@ func testChannel(channel *model.Channel, testModel string) testResult { var convertedRequest any // 根据 RelayMode 选择正确的转换函数 - if info.RelayMode == relayconstant.RelayModeEmbeddings { - // 创建一个 EmbeddingRequest - embeddingRequest := dto.EmbeddingRequest{ - Input: request.Input, - Model: request.Model, - } - // 调用专门用于 Embedding 的转换函数 - convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest) - } else if info.RelayMode == relayconstant.RelayModeImagesGenerations { - // 创建一个 ImageRequest - prompt := "cat" - if request.Prompt != nil { - if promptStr, ok := request.Prompt.(string); ok && promptStr != "" { - prompt = promptStr + switch info.RelayMode { + case relayconstant.RelayModeEmbeddings: + // Embedding 请求 - request 已经是正确的类型 + if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok { + convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid embedding request type"), + newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed), } } - imageRequest := dto.ImageRequest{ - Prompt: prompt, - Model: request.Model, - N: uint(request.N), - Size: request.Size, + case relayconstant.RelayModeImagesGenerations: + // 图像生成请求 - request 已经是正确的类型 + if imageReq, ok := request.(*dto.ImageRequest); ok { + convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid image request type"), + newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed), + } + } + case relayconstant.RelayModeRerank: + // Rerank 请求 - request 已经是正确的类型 + if rerankReq, ok := request.(*dto.RerankRequest); ok { + convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid rerank request type"), + newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed), + } + } + case relayconstant.RelayModeResponses: + // Response 请求 - request 已经是正确的类型 + if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok { + convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid response request type"), + newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed), + } + } + default: + // Chat/Completion 等其他请求类型 + if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok { + convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq) + } else { + return testResult{ + context: c, + localErr: errors.New("invalid general request type"), + newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed), + } } - // 调用专门用于图像生成的转换函数 - convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest) - } else { - // 对其他所有请求类型(如 Chat),保持原有逻辑 - convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request) } if err != nil { @@ -345,22 +405,82 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } -func buildTestRequest(model string) *dto.GeneralOpenAIRequest { - testRequest := &dto.GeneralOpenAIRequest{ - Model: "", // this will be set later - Stream: false, +func buildTestRequest(model string, endpointType string) dto.Request { + // 根据端点类型构建不同的测试请求 + if endpointType != "" { + switch constant.EndpointType(endpointType) { + case constant.EndpointTypeEmbeddings: + // 返回 EmbeddingRequest + return &dto.EmbeddingRequest{ + Model: model, + Input: []any{"hello world"}, + } + case constant.EndpointTypeImageGeneration: + // 返回 ImageRequest + return &dto.ImageRequest{ + Model: model, + Prompt: "a cute cat", + N: 1, + Size: "1024x1024", + } + case constant.EndpointTypeJinaRerank: + // 返回 RerankRequest + return &dto.RerankRequest{ + Model: model, + Query: "What is Deep Learning?", + Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."}, + TopN: 2, + } + case constant.EndpointTypeOpenAIResponse: + // 返回 OpenAIResponsesRequest + return &dto.OpenAIResponsesRequest{ + Model: model, + Input: json.RawMessage("\"hi\""), + } + case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI: + // 返回 GeneralOpenAIRequest + maxTokens := uint(10) + if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { + maxTokens = 3000 + } + return &dto.GeneralOpenAIRequest{ + Model: model, + Stream: false, + Messages: []dto.Message{ + { + Role: "user", + Content: "hi", + }, + }, + MaxTokens: maxTokens, + } + } } + // 自动检测逻辑(保持原有行为) // 先判断是否为 Embedding 模型 - if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型 - strings.HasPrefix(model, "m3e") || // m3e 系列模型 + if strings.Contains(strings.ToLower(model), "embedding") || + strings.HasPrefix(model, "m3e") || strings.Contains(model, "bge-") { - testRequest.Model = model - // Embedding 请求 - testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型 - return testRequest + // 返回 EmbeddingRequest + return &dto.EmbeddingRequest{ + Model: model, + Input: []any{"hello world"}, + } } - // 并非Embedding 模型 + + // Chat/Completion 请求 - 返回 GeneralOpenAIRequest + testRequest := &dto.GeneralOpenAIRequest{ + Model: model, + Stream: false, + Messages: []dto.Message{ + { + Role: "user", + Content: "hi", + }, + }, + } + if strings.HasPrefix(model, "o") { testRequest.MaxCompletionTokens = 10 } else if strings.Contains(model, "thinking") { @@ -373,12 +493,6 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest { testRequest.MaxTokens = 10 } - testMessage := dto.Message{ - Role: "user", - Content: "hi", - } - testRequest.Model = model - testRequest.Messages = append(testRequest.Messages, testMessage) return testRequest } @@ -402,8 +516,9 @@ func TestChannel(c *gin.Context) { // } //}() testModel := c.Query("model") + endpointType := c.Query("endpoint_type") tik := time.Now() - result := testChannel(channel, testModel) + result := testChannel(channel, testModel, endpointType) if result.localErr != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -429,7 +544,6 @@ func TestChannel(c *gin.Context) { "message": "", "time": consumedTime, }) - return } var testAllChannelsLock sync.Mutex @@ -463,7 +577,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - result := testChannel(channel, "") + result := testChannel(channel, "", "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() @@ -477,7 +591,7 @@ func testAllChannels(notify bool) error { // 当错误检查通过,才检查响应时间 if common.AutomaticDisableChannelEnabled && !shouldBanChannel { if milliseconds > disableThreshold { - err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) shouldBanChannel = true } @@ -514,7 +628,6 @@ func TestAllChannels(c *gin.Context) { "success": true, "message": "", }) - return } var autoTestChannelsOnce sync.Once diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 21d6e1705..234ab4c99 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -195,21 +195,29 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] } - switch info.RelayMode { - case constant.RelayModeChatCompletions: + switch info.RelayFormat { + case types.RelayFormatClaude: if strings.HasPrefix(info.UpstreamModelName, "bot") { return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil } return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil - case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil - case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil - case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil - case constant.RelayModeRerank: - return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil default: + switch info.RelayMode { + case constant.RelayModeChatCompletions: + if strings.HasPrefix(info.UpstreamModelName, "bot") { + return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil + } + return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil + case constant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil + case constant.RelayModeRerank: + return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil + default: + } } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) } diff --git a/service/passkey/service.go b/service/passkey/service.go index 62befb9d3..dc8da0ccc 100644 --- a/service/passkey/service.go +++ b/service/passkey/service.go @@ -80,9 +80,11 @@ func BuildWebAuthn(r *http.Request) (*webauthn.WebAuthn, error) { } func resolveOrigins(r *http.Request, settings *system_setting.PasskeySettings) ([]string, error) { - if len(settings.Origins) > 0 { - origins := make([]string, 0, len(settings.Origins)) - for _, origin := range settings.Origins { + originsStr := strings.TrimSpace(settings.Origins) + if originsStr != "" { + originList := strings.Split(originsStr, ",") + origins := make([]string, 0, len(originList)) + for _, origin := range originList { trimmed := strings.TrimSpace(origin) if trimmed == "" { continue diff --git a/setting/system_setting/passkey.go b/setting/system_setting/passkey.go index 54746e808..a0766a67b 100644 --- a/setting/system_setting/passkey.go +++ b/setting/system_setting/passkey.go @@ -1,25 +1,27 @@ package system_setting import ( + "net/url" "one-api/common" "one-api/setting/config" + "strings" ) type PasskeySettings struct { - Enabled bool `json:"enabled"` - RPDisplayName string `json:"rp_display_name"` - RPID string `json:"rp_id"` - Origins []string `json:"origins"` - AllowInsecureOrigin bool `json:"allow_insecure_origin"` - UserVerification string `json:"user_verification"` - AttachmentPreference string `json:"attachment_preference"` + Enabled bool `json:"enabled"` + RPDisplayName string `json:"rp_display_name"` + RPID string `json:"rp_id"` + Origins string `json:"origins"` + AllowInsecureOrigin bool `json:"allow_insecure_origin"` + UserVerification string `json:"user_verification"` + AttachmentPreference string `json:"attachment_preference"` } var defaultPasskeySettings = PasskeySettings{ Enabled: false, RPDisplayName: common.SystemName, RPID: "", - Origins: []string{}, + Origins: "", AllowInsecureOrigin: false, UserVerification: "preferred", AttachmentPreference: "", @@ -30,5 +32,18 @@ func init() { } func GetPasskeySettings() *PasskeySettings { + if defaultPasskeySettings.RPID == "" && ServerAddress != "" { + // 从ServerAddress提取域名作为RPID + // ServerAddress可能是 "https://newapi.pro" 这种格式 + serverAddr := strings.TrimSpace(ServerAddress) + if parsed, err := url.Parse(serverAddr); err == nil && parsed.Host != "" { + defaultPasskeySettings.RPID = parsed.Host + } else { + defaultPasskeySettings.RPID = serverAddr + } + } + if defaultPasskeySettings.Origins == "" || defaultPasskeySettings.Origins == "[]" { + defaultPasskeySettings.Origins = ServerAddress + } return &defaultPasskeySettings } diff --git a/web/src/components/settings/SystemSetting.jsx b/web/src/components/settings/SystemSetting.jsx index f0c2dbc3a..112d104a6 100644 --- a/web/src/components/settings/SystemSetting.jsx +++ b/web/src/components/settings/SystemSetting.jsx @@ -122,7 +122,6 @@ const SystemSetting = () => { const [domainList, setDomainList] = useState([]); const [ipList, setIpList] = useState([]); const [allowedPorts, setAllowedPorts] = useState([]); - const [passkeyOrigins, setPasskeyOrigins] = useState([]); const getOptions = async () => { setLoading(true); @@ -188,22 +187,19 @@ const SystemSetting = () => { item.value = toBoolean(item.value); break; case 'passkey.origins': - try { - const origins = item.value ? JSON.parse(item.value) : []; - setPasskeyOrigins(Array.isArray(origins) ? origins : []); - item.value = Array.isArray(origins) ? origins : []; - } catch (e) { - setPasskeyOrigins([]); - item.value = []; - } + // origins是逗号分隔的字符串,直接使用 + item.value = item.value || ''; break; case 'passkey.rp_display_name': case 'passkey.rp_id': - case 'passkey.user_verification': case 'passkey.attachment_preference': // 确保字符串字段不为null/undefined item.value = item.value || ''; break; + case 'passkey.user_verification': + // 确保有默认值 + item.value = item.value || 'preferred'; + break; case 'Price': case 'MinTopUp': item.value = parseFloat(item.value); @@ -611,42 +607,33 @@ const SystemSetting = () => { }; const submitPasskeySettings = async () => { + // 使用formApi直接获取当前表单值 + const formValues = formApiRef.current?.getValues() || {}; + const options = []; - // 只在值有变化时才提交,并确保空值转换为空字符串 - if (originInputs['passkey.rp_display_name'] !== inputs['passkey.rp_display_name']) { - options.push({ - key: 'passkey.rp_display_name', - value: inputs['passkey.rp_display_name'] || '', - }); - } - if (originInputs['passkey.rp_id'] !== inputs['passkey.rp_id']) { - options.push({ - key: 'passkey.rp_id', - value: inputs['passkey.rp_id'] || '', - }); - } - if (originInputs['passkey.user_verification'] !== inputs['passkey.user_verification']) { - options.push({ - key: 'passkey.user_verification', - value: inputs['passkey.user_verification'] || 'preferred', - }); - } - if (originInputs['passkey.attachment_preference'] !== inputs['passkey.attachment_preference']) { - options.push({ - key: 'passkey.attachment_preference', - value: inputs['passkey.attachment_preference'] || '', - }); - } - // Origins总是提交,因为它们可能会被用户清空 + options.push({ + key: 'passkey.rp_display_name', + value: formValues['passkey.rp_display_name'] || inputs['passkey.rp_display_name'] || '', + }); + options.push({ + key: 'passkey.rp_id', + value: formValues['passkey.rp_id'] || inputs['passkey.rp_id'] || '', + }); + options.push({ + key: 'passkey.user_verification', + value: formValues['passkey.user_verification'] || inputs['passkey.user_verification'] || 'preferred', + }); + options.push({ + key: 'passkey.attachment_preference', + value: formValues['passkey.attachment_preference'] || inputs['passkey.attachment_preference'] || '', + }); options.push({ key: 'passkey.origins', - value: JSON.stringify(Array.isArray(passkeyOrigins) ? passkeyOrigins : []), + value: formValues['passkey.origins'] || inputs['passkey.origins'] || '', }); - if (options.length > 0) { - await updateOptions(options); - } + await updateOptions(options); }; const handleCheckboxChange = async (optionKey, event) => { @@ -1037,7 +1024,7 @@ const SystemSetting = () => { >