Merge branch 'main-upstream' into fix/volcengine_default_baseurl

# Conflicts:
#	main.go
Seefs
2025-09-29 12:08:52 +08:00
16 changed files with 880 additions and 348 deletions

View File

@@ -2,9 +2,10 @@ package common
import (
"fmt"
"github.com/gin-gonic/gin"
"os"
"time"
"github.com/gin-gonic/gin"
)
func SysLog(s string) {
@@ -22,3 +23,33 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
func LogStartupSuccess(startTime time.Time, port string) {
duration := time.Since(startTime)
durationMs := duration.Milliseconds()
// Get network IPs
networkIps := GetNetworkIps()
// Print blank line for spacing
fmt.Fprintf(gin.DefaultWriter, "\n")
// Print the main success message
fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
fmt.Fprintf(gin.DefaultWriter, "\n")
// Skip fancy startup message in container environments
if !IsRunningInContainer() {
// Print local URL
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
}
// Print network URLs
for _, ip := range networkIps {
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
}
// Print blank line for spacing
fmt.Fprintf(gin.DefaultWriter, "\n")
}
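For reference, on a non-container host with one private address the banner prints roughly like this (product name, version, timing, and IP are illustrative):

    one-api v0.0.0 ready in 321 ms

    ➜ Local:   http://localhost:3000/
    ➜ Network: http://192.168.1.10:3000/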

View File

@@ -68,6 +68,78 @@ func GetIp() (ip string) {
return
}
func GetNetworkIps() []string {
var networkIps []string
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return networkIps
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip := ipNet.IP.String()
// Include common private network ranges
if strings.HasPrefix(ip, "10.") ||
strings.HasPrefix(ip, "172.") ||
strings.HasPrefix(ip, "192.168.") {
networkIps = append(networkIps, ip)
}
}
}
}
return networkIps
}
// IsRunningInContainer detects if the application is running inside a container
func IsRunningInContainer() bool {
// Method 1: Check for .dockerenv file (Docker containers)
if _, err := os.Stat("/.dockerenv"); err == nil {
return true
}
// Method 2: Check cgroup for container indicators
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
content := string(data)
if strings.Contains(content, "docker") ||
strings.Contains(content, "containerd") ||
strings.Contains(content, "kubepods") ||
strings.Contains(content, "/lxc/") {
return true
}
}
// Method 3: Check environment variables commonly set by container runtimes
containerEnvVars := []string{
"KUBERNETES_SERVICE_HOST",
"DOCKER_CONTAINER",
"container",
}
for _, envVar := range containerEnvVars {
if os.Getenv(envVar) != "" {
return true
}
}
// Method 4: Check if init process is not the traditional init
if data, err := os.ReadFile("/proc/1/comm"); err == nil {
comm := strings.TrimSpace(string(data))
// In containers, process 1 is often not "init" or "systemd"
if comm != "init" && comm != "systemd" {
// Additional check: if it's a common container entrypoint
if strings.Contains(comm, "docker") ||
strings.Contains(comm, "containerd") ||
strings.Contains(comm, "runc") {
return true
}
}
}
return false
}
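Because Method 3 only reads environment variables, the detection can be exercised from a test without touching /proc; a minimal sketch (hypothetical test in package common; t.Setenv requires Go 1.17+):

package common

import "testing"

// Simulate a container runtime via one of the env vars checked in Method 3.
func TestIsRunningInContainerEnv(t *testing.T) {
    t.Setenv("KUBERNETES_SERVICE_HOST", "10.0.0.1")
    if !IsRunningInContainer() {
        t.Fatal("expected container detection via KUBERNETES_SERVICE_HOST")
    }
}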
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024

View File

@@ -19,4 +19,12 @@ const (
type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
}
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
if s == nil || s.OpenRouterEnterprise == nil {
return false
}
return *s.OpenRouterEnterprise
}
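The pointer field plus the nil-safe receiver lets callers treat "unset" and "explicitly false" the same way without guarding; a minimal sketch (illustrative, assumes fmt is imported):

func exampleEnterpriseFlag() {
    var s *ChannelOtherSettings // channel with no settings stored at all
    fmt.Println(s.IsOpenRouterEnterprise()) // false: nil receiver and nil field are both handled
    enabled := true
    s = &ChannelOtherSettings{OpenRouterEnterprise: &enabled}
    fmt.Println(s.IsOpenRouterEnterprise()) // true
}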

View File

@@ -18,6 +18,7 @@ import (
"os"
"strconv"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
@@ -35,6 +36,7 @@ var buildFS embed.FS
var indexPage []byte
func main() {
startTime := time.Now()
err := InitResources()
if err != nil {
@@ -168,6 +170,10 @@ func main() {
if port == "" {
port = strconv.Itoa(*common.Port)
}
// Log startup success message
common.LogStartupSuccess(startTime, port)
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
@@ -222,4 +228,4 @@ func InitResources() error {
return err
}
return nil
}

View File

@@ -265,6 +265,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
resp, err := client.Do(req)
if err != nil {
logger.LogError(c, "do request failed: "+err.Error())
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
}
if resp == nil {

View File

@@ -10,6 +10,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
@@ -17,10 +18,7 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
openaiAdaptor := openai.Adaptor{}
@@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == types.RelayFormatClaude {
return info.ChannelBaseUrl + "/v1/chat/completions", nil
}
if info.RelayMode == relayconstant.RelayModeEmbeddings {
return info.ChannelBaseUrl + "/api/embed", nil
}
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
return info.ChannelBaseUrl + "/api/generate", nil
}
return info.ChannelBaseUrl + "/api/chat", nil
}
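In effect the adaptor now targets Ollama's native endpoints rather than the generic OpenAI passthrough URL; the resulting routing (base URL illustrative):

    Claude relay format      -> {base}/v1/chat/completions
    embeddings relay mode    -> {base}/api/embed
    /v1/completions requests -> {base}/api/generate
    everything else          -> {base}/api/chat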
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
// decide between the generate and chat endpoints
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
return openAIToGenerate(c, request)
}
return openAIChatToOllamaChat(c, request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
@@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return ollamaEmbeddingHandler(c, info, resp)
default:
if info.IsStream {
return ollamaStreamHandler(c, info, resp)
}
return ollamaChatHandler(c, info, resp)
}
}
func (a *Adaptor) GetModelList() []string {

View File

@@ -2,48 +2,69 @@ package ollama
import (
"encoding/json"
"one-api/dto"
)
type OllamaChatMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Images []string `json:"images,omitempty"`
ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
Thinking json.RawMessage `json:"thinking,omitempty"`
}
type OllamaToolFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters interface{} `json:"parameters,omitempty"`
}
type OllamaTool struct {
Type string `json:"type"`
Function OllamaToolFunction `json:"function"`
}
type OllamaToolCall struct {
Function struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
} `json:"function"`
}
type OllamaChatRequest struct {
Model string `json:"model"`
Messages []OllamaChatMessage `json:"messages"`
Tools interface{} `json:"tools,omitempty"`
Format interface{} `json:"format,omitempty"`
Stream bool `json:"stream,omitempty"`
Options map[string]any `json:"options,omitempty"`
KeepAlive interface{} `json:"keep_alive,omitempty"`
Think json.RawMessage `json:"think,omitempty"`
}
type OllamaGenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
Suffix string `json:"suffix,omitempty"`
Images []string `json:"images,omitempty"`
Format interface{} `json:"format,omitempty"`
Stream bool `json:"stream,omitempty"`
Options map[string]any `json:"options,omitempty"`
KeepAlive interface{} `json:"keep_alive,omitempty"`
Think json.RawMessage `json:"think,omitempty"`
}
type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
Model string `json:"model"`
Input interface{} `json:"input"`
Options map[string]any `json:"options,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
}
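A request assembled from these types might look like the following minimal sketch (field values are illustrative):

req := OllamaChatRequest{
    Model:  "llama3.1",
    Stream: true,
    Messages: []OllamaChatMessage{
        {Role: "user", Content: "Describe this image", Images: []string{"<base64 image data>"}},
    },
    Options: map[string]any{"temperature": 0.7, "num_predict": 256},
}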

View File

@@ -1,6 +1,7 @@
package ollama
import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -14,121 +15,176 @@ import (
"github.com/gin-gonic/gin"
)
func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
chatReq := &OllamaChatRequest{
Model: r.Model,
Stream: r.Stream,
Options: map[string]any{},
Think: r.Think,
}
if r.ResponseFormat != nil {
if r.ResponseFormat.Type == "json" {
chatReq.Format = "json"
} else if r.ResponseFormat.Type == "json_schema" {
if len(r.ResponseFormat.JsonSchema) > 0 {
var schema any
_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
chatReq.Format = schema
}
}
}
// options mapping
if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
if r.Stop != nil {
switch v := r.Stop.(type) {
case string:
chatReq.Options["stop"] = []string{v}
case []string:
chatReq.Options["stop"] = v
case []any:
arr := make([]string, 0, len(v))
for _, i := range v { if s, ok := i.(string); ok { arr = append(arr, s) } }
if len(arr) > 0 { chatReq.Options["stop"] = arr }
}
}
if len(r.Tools) > 0 {
tools := make([]OllamaTool, 0, len(r.Tools))
for _, t := range r.Tools {
tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
}
chatReq.Tools = tools
}
chatReq.Messages = make([]OllamaChatMessage, 0, len(r.Messages))
for _, m := range r.Messages {
var textBuilder strings.Builder
var images []string
if m.IsStringContent() {
textBuilder.WriteString(m.StringContent())
} else {
parts := m.ParseContent()
for _, part := range parts {
if part.Type == dto.ContentTypeImageURL {
img := part.GetImageMedia()
if img != nil && img.Url != "" {
var base64Data string
if strings.HasPrefix(img.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
if err != nil { return nil, err }
base64Data = fileData.Base64Data
} else if strings.HasPrefix(img.Url, "data:") {
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
} else {
base64Data = img.Url
}
if base64Data != "" { images = append(images, base64Data) }
}
} else if part.Type == dto.ContentTypeText {
textBuilder.WriteString(part.Text)
}
}
}
cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
if len(images) > 0 { cm.Images = images }
if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
parsed := m.ParseToolCalls()
if len(parsed) > 0 {
calls := make([]OllamaToolCall, 0, len(parsed))
for _, tc := range parsed {
var args interface{}
if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
if args == nil { args = map[string]any{} }
oc := OllamaToolCall{}
oc.Function.Name = tc.Function.Name
oc.Function.Arguments = args
calls = append(calls, oc)
}
cm.ToolCalls = calls
}
}
chatReq.Messages = append(chatReq.Messages, cm)
}
return chatReq, nil
}
// openAIToGenerate converts OpenAI completions request to Ollama generate
func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
gen := &OllamaGenerateRequest{
Model: r.Model,
Stream: r.Stream,
Options: map[string]any{},
Think: r.Think,
}
// Prompt may be in r.Prompt (string or []any)
if r.Prompt != nil {
switch v := r.Prompt.(type) {
case string:
gen.Prompt = v
case []any:
var sb strings.Builder
for _, it := range v { if s, ok := it.(string); ok { sb.WriteString(s) } }
gen.Prompt = sb.String()
default:
gen.Prompt = fmt.Sprintf("%v", r.Prompt)
}
}
if r.Suffix != nil { if s, ok := r.Suffix.(string); ok { gen.Suffix = s } }
if r.ResponseFormat != nil {
if r.ResponseFormat.Type == "json" {
gen.Format = "json"
} else if r.ResponseFormat.Type == "json_schema" {
var schema any
_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
gen.Format = schema
}
}
if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
if r.Stop != nil {
switch v := r.Stop.(type) {
case string:
gen.Options["stop"] = []string{v}
case []string:
gen.Options["stop"] = v
case []any:
arr := make([]string, 0, len(v))
for _, i := range v { if s, ok := i.(string); ok { arr = append(arr, s) } }
if len(arr) > 0 { gen.Options["stop"] = arr }
}
}
return gen, nil
}
func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
opts := map[string]any{}
if r.Temperature != nil { opts["temperature"] = r.Temperature }
if r.TopP != 0 { opts["top_p"] = r.TopP }
if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
if r.Seed != 0 { opts["seed"] = int(r.Seed) }
if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
input := r.ParseInput()
if len(input) == 1 {
return &OllamaEmbeddingRequest{Model: r.Model, Input: input[0], Options: opts, Dimensions: r.Dimensions}
}
return &OllamaEmbeddingRequest{Model: r.Model, Input: input, Options: opts, Dimensions: r.Dimensions}
}
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var oResp OllamaEmbeddingResponse
body, err := io.ReadAll(resp.Body)
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
service.CloseResponseBodyGracefully(resp)
if err = common.Unmarshal(body, &oResp); err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if oResp.Error != "" {
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
data := make([]dto.OpenAIEmbeddingResponseItem, 0, len(oResp.Embeddings))
for i, emb := range oResp.Embeddings {
data = append(data, dto.OpenAIEmbeddingResponseItem{Index: i, Object: "embedding", Embedding: emb})
}
usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens: 0, TotalTokens: oResp.PromptEvalCount}
embResp := &dto.OpenAIEmbeddingResponse{Object: "list", Data: data, Model: info.UpstreamModelName, Usage: *usage}
out, _ := common.Marshal(embResp)
service.IOCopyBytesGracefully(c, resp, out)
return usage, nil
}

View File

@@ -0,0 +1,210 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type ollamaChatStreamChunk struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
// chat
Message *struct {
Role string `json:"role"`
Content string `json:"content"`
Thinking json.RawMessage `json:"thinking"`
ToolCalls []struct {
Function struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
// generate
Response string `json:"response"`
Done bool `json:"done"`
DoneReason string `json:"done_reason"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
EvalCount int `json:"eval_count"`
PromptEvalDuration int64 `json:"prompt_eval_duration"`
EvalDuration int64 `json:"eval_duration"`
}
func toUnix(ts string) int64 {
if ts == "" {
return time.Now().Unix()
}
// try RFC3339 with nanoseconds first, then plain RFC3339
t, err := time.Parse(time.RFC3339Nano, ts)
if err != nil {
if t2, err2 := time.Parse(time.RFC3339, ts); err2 == nil {
return t2.Unix()
}
return time.Now().Unix()
}
return t.Unix()
}
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
defer service.CloseResponseBodyGracefully(resp)
helper.SetEventStreamHeaders(c)
scanner := bufio.NewScanner(resp.Body)
usage := &dto.Usage{}
var model = info.UpstreamModelName
var responseId = common.GetUUID()
var created = time.Now().Unix()
var toolCallIndex int
start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" { continue }
var chunk ollamaChatStreamChunk
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if chunk.Model != "" { model = chunk.Model }
created = toUnix(chunk.CreatedAt)
if !chunk.Done {
// delta content
var content string
if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
delta := dto.ChatCompletionsStreamResponse{
Id: responseId,
Object: "chat.completion.chunk",
Created: created,
Model: model,
Choices: []dto.ChatCompletionsStreamResponseChoice{ {
Index: 0,
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
} },
}
if content != "" { delta.Choices[0].Delta.SetContentString(content) }
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(chunk.Message.Thinking))
if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
}
// tool calls
if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
for _, tc := range chunk.Message.ToolCalls {
// arguments -> string
argBytes, _ := json.Marshal(tc.Function.Arguments)
toolId := fmt.Sprintf("call_%d", toolCallIndex)
tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
tr.SetIndex(toolCallIndex)
toolCallIndex++
delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
}
}
if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
continue
}
// done frame
// finalize once and break loop
usage.PromptTokens = chunk.PromptEvalCount
usage.CompletionTokens = chunk.EvalCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
finishReason := chunk.DoneReason
if finishReason == "" { finishReason = "stop" }
// emit stop delta
if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
}
// emit usage frame
if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
}
// send [DONE]
helper.Done(c)
break
}
if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
return usage, nil
}
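For reference, the handler above emits OpenAI-style SSE frames in a fixed order: an empty start chunk, one delta per Ollama line (content, reasoning, and tool calls), a stop chunk carrying the mapped finish reason, a final usage chunk, and then [DONE].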
// non-stream handler for chat/generate
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
body, err := io.ReadAll(resp.Body)
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
service.CloseResponseBodyGracefully(resp)
raw := string(body)
if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
lines := strings.Split(raw, "\n")
var (
aggContent strings.Builder
reasoningBuilder strings.Builder
lastChunk ollamaChatStreamChunk
parsedAny bool
)
for _, ln := range lines {
ln = strings.TrimSpace(ln)
if ln == "" { continue }
var ck ollamaChatStreamChunk
if err := json.Unmarshal([]byte(ln), &ck); err != nil {
if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
continue
}
parsedAny = true
lastChunk = ck
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(ck.Message.Thinking))
if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
}
if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
}
if !parsedAny {
var single ollamaChatStreamChunk
if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
lastChunk = single
if single.Message != nil {
if len(single.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(single.Message.Thinking))
if raw != "" && raw != "null" {
reasoningBuilder.WriteString(raw)
}
}
aggContent.WriteString(single.Message.Content)
} else { aggContent.WriteString(single.Response) }
}
model := lastChunk.Model
if model == "" { model = info.UpstreamModelName }
created := toUnix(lastChunk.CreatedAt)
usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
content := aggContent.String()
finishReason := lastChunk.DoneReason
if finishReason == "" { finishReason = "stop" }
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
full := dto.OpenAITextResponse{
Id: common.GetUUID(),
Model: model,
Object: "chat.completion",
Created: created,
Choices: []dto.OpenAITextResponseChoice{ {
Index: 0,
Message: msg,
FinishReason: finishReason,
} },
Usage: *usage,
}
out, _ := common.Marshal(full)
service.IOCopyBytesGracefully(c, resp, out)
return usage, nil
}
func contentPtr(s string) *string {
if s == "" {
return nil
}
return &s
}

View File

@@ -12,6 +12,7 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if common.DebugEnabled {
println("upstream response body:", string(responseBody))
}
// Unmarshal to simpleResponse
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
// try to parse the body as an openrouter enterprise envelope
var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
err = common.Unmarshal(responseBody, &enterpriseResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if enterpriseResponse.Success {
responseBody = enterpriseResponse.Data
} else {
logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
}
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}

View File

@@ -1,5 +1,7 @@
package openrouter
import "encoding/json"
type RequestReasoning struct {
// One of the following (not both):
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
@@ -7,3 +9,8 @@ type RequestReasoning struct {
// Optional: Default is false. All models support this.
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
}
type OpenRouterEnterpriseResponse struct {
Data json.RawMessage `json:"data"`
Success bool `json:"success"`
}
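A minimal sketch of unwrapping the envelope as the handler above does (payload is illustrative):

raw := []byte(`{"success":true,"data":{"id":"cmpl-1","object":"chat.completion"}}`)
var env OpenRouterEnterpriseResponse
if err := json.Unmarshal(raw, &env); err == nil && env.Success {
    _ = env.Data // the inner OpenAI-style response body, parsed downstream as usual
}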

View File

@@ -9,6 +9,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
channelconstant "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
@@ -188,20 +189,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// support custom domains; fall back to the default domain when unset
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
}
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

View File

@@ -9,6 +9,11 @@ var ModelList = []string{
"Doubao-lite-4k",
"Doubao-embedding",
"doubao-seedream-4-0-250828",
"seedream-4-0-250828",
"doubao-seedance-1-0-pro-250528",
"seedance-1-0-pro-250528",
"doubao-seed-1-6-thinking-250715",
"seed-1-6-thinking-250715",
}
var ChannelName = "volcengine"

View File

@@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
return nil, nil, err
}
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
@@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
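// close the websocket only after the read loop below finishes, not when xunfeiMakeRequest returns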
defer func() {
conn.Close()
}()
for {
_, msg, err := conn.ReadMessage()
if err != nil {

View File

@@ -164,6 +164,8 @@ const EditChannelModal = (props) => {
settings: '',
// Vertex only: key format (stored in settings.vertex_key_type)
vertex_key_type: 'json',
// enterprise account settings
is_enterprise_account: false,
};
const [batch, setBatch] = useState(false);
const [multiToSingle, setMultiToSingle] = useState(false);
@@ -189,6 +191,7 @@ const EditChannelModal = (props) => {
const [channelSearchValue, setChannelSearchValue] = useState('');
const [useManualInput, setUseManualInput] = useState(false); // whether manual input mode is active
const [keyMode, setKeyMode] = useState('append'); // key mode: 'replace' (overwrite) or 'append'
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // whether this is an enterprise account
// state for viewing keys behind 2FA verification
const [twoFAState, setTwoFAState] = useState({
@@ -235,7 +238,7 @@ const EditChannelModal = (props) => {
pass_through_body_enabled: false,
system_prompt: '',
});
const showApiConfigCard = true; // controls whether the API config card is shown
const getInitValues = () => ({ ...originInputs });
// handle updates to channel extra settings
@@ -342,6 +345,10 @@ const EditChannelModal = (props) => {
case 36:
localModels = ['suno_music', 'suno_lyrics'];
break;
case 45:
localModels = getChannelModels(value);
setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' }));
break;
default:
localModels = getChannelModels(value);
break;
@@ -433,15 +440,19 @@ const EditChannelModal = (props) => {
parsedSettings.azure_responses_version || '';
// read the Vertex key format
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
// read the enterprise account setting
data.is_enterprise_account = parsedSettings.openrouter_enterprise === true;
} catch (error) {
console.error('解析其他设置失败:', error);
data.azure_responses_version = '';
data.region = '';
data.vertex_key_type = 'json';
data.is_enterprise_account = false;
}
} else {
// backward compatibility: old channels without settings default to 'json'
data.vertex_key_type = 'json';
data.is_enterprise_account = false;
}
setInputs(data);
@@ -453,6 +464,8 @@ const EditChannelModal = (props) => {
} else {
setAutoBan(true);
}
// sync the enterprise account state
setIsEnterpriseAccount(data.is_enterprise_account || false);
setBasicModels(getChannelModels(data.type));
// sync the channelSettings state for display
setChannelSettings({
@@ -712,6 +725,8 @@ const EditChannelModal = (props) => {
});
// reset the key mode state
setKeyMode('append');
// reset the enterprise account state
setIsEnterpriseAccount(false);
// clear the key_mode field in the form
if (formApiRef.current) {
formApiRef.current.setValue('key_mode', undefined);
@@ -844,6 +859,10 @@ const EditChannelModal = (props) => {
showInfo(t('请至少选择一个模型!'));
return;
}
if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) {
showInfo(t('请输入API地址'));
return;
}
if (
localInputs.model_mapping &&
localInputs.model_mapping !== '' &&
@@ -873,6 +892,21 @@ const EditChannelModal = (props) => {
};
localInputs.setting = JSON.stringify(channelExtraSettings);
// handle the enterprise account setting for type === 20 (OpenRouter)
if (localInputs.type === 20) {
let settings = {};
if (localInputs.settings) {
try {
settings = JSON.parse(localInputs.settings);
} catch (error) {
console.error('解析settings失败:', error);
}
}
// set the enterprise account flag; both true and false must be sent to the backend
settings.openrouter_enterprise = localInputs.is_enterprise_account === true;
localInputs.settings = JSON.stringify(settings);
}
// strip fields that should not be sent to the backend
delete localInputs.force_format;
delete localInputs.thinking_to_content;
@@ -880,6 +914,7 @@ const EditChannelModal = (props) => {
delete localInputs.pass_through_body_enabled;
delete localInputs.system_prompt;
delete localInputs.system_prompt_override;
delete localInputs.is_enterprise_account;
// the top-level vertex_key_type must not be sent to the backend
delete localInputs.vertex_key_type;
@@ -1264,6 +1299,21 @@ const EditChannelModal = (props) => {
onChange={(value) => handleInputChange('type', value)}
/>
{inputs.type === 20 && (
<Form.Switch
field='is_enterprise_account'
label={t('是否为企业账户')}
checkedText={t('是')}
uncheckedText={t('否')}
onChange={(value) => {
setIsEnterpriseAccount(value);
handleInputChange('is_enterprise_account', value);
}}
extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')}
initValue={inputs.is_enterprise_account}
/>
)}
<Form.Input
field='name'
label={t('名称')}
@@ -1883,6 +1933,30 @@ const EditChannelModal = (props) => {
/>
</div>
)}
{inputs.type === 45 && (
<div>
<Form.Select
field='base_url'
label={t('API地址')}
placeholder={t('请选择API地址')}
onChange={(value) =>
handleInputChange('base_url', value)
}
optionList={[
{
value: 'https://ark.cn-beijing.volces.com',
label: 'https://ark.cn-beijing.volces.com'
},
{
value: 'https://ark.ap-southeast.bytepluses.com',
label: 'https://ark.ap-southeast.bytepluses.com'
}
]}
defaultValue='https://ark.cn-beijing.volces.com'
/>
</div>
)}
</Card>
)}

View File

@@ -25,13 +25,9 @@ import {
showInfo,
showSuccess,
loadChannelModels,
copy
} from '../../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE, MODEL_TABLE_PAGE_SIZE } from '../../constants';
import { useIsMobile } from '../common/useIsMobile';
import { useTableCompactMode } from '../common/useTableCompactMode';
import { Modal } from '@douyinfe/semi-ui';
@@ -68,7 +64,7 @@ export const useChannelsData = () => {
// Status filter
const [statusFilter, setStatusFilter] = useState(
localStorage.getItem('channel-status-filter') || 'all'
);
// Type tabs states
@@ -83,9 +79,10 @@ export const useChannelsData = () => {
const [testingModels, setTestingModels] = useState(new Set());
const [selectedModelKeys, setSelectedModelKeys] = useState([]);
const [isBatchTesting, setIsBatchTesting] = useState(false);
const [modelTablePage, setModelTablePage] = useState(1);
// use a ref to avoid stale-closure issues, mirroring the old implementation
const shouldStopBatchTestingRef = useRef(false);
// Multi-key management states
const [showMultiKeyManageModal, setShowMultiKeyManageModal] = useState(false);
@@ -119,12 +116,9 @@ export const useChannelsData = () => {
// Initialize from localStorage
useEffect(() => {
const localIdSort = localStorage.getItem('id-sort') === 'true';
const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
const localEnableTagMode = localStorage.getItem('enable-tag-mode') === 'true';
const localEnableBatchDelete = localStorage.getItem('enable-batch-delete') === 'true';
setIdSort(localIdSort);
setPageSize(localPageSize);
@@ -182,10 +176,7 @@ export const useChannelsData = () => {
// Save column preferences
useEffect(() => {
if (Object.keys(visibleColumns).length > 0) {
localStorage.setItem('channels-table-columns', JSON.stringify(visibleColumns));
}
}, [visibleColumns]);
@@ -299,21 +290,14 @@ export const useChannelsData = () => {
const { searchKeyword, searchGroup, searchModel } = getFormValues();
if (searchKeyword !== '' || searchGroup !== '' || searchModel !== '') {
setLoading(true);
await searchChannels(enableTagMode, typeKey, statusF, page, pageSize, idSort);
setLoading(false);
return;
}
const reqId = ++requestCounter.current;
setLoading(true);
const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : '';
const statusParam = statusF !== 'all' ? `&status=${statusF}` : '';
const res = await API.get(
`/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}${statusParam}`,
@@ -327,10 +311,7 @@ export const useChannelsData = () => {
if (success) {
const { items, total, type_counts } = data;
if (type_counts) {
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
setTypeCounts({ ...type_counts, all: sumAll });
}
setChannelFormat(items, enableTagMode);
@@ -354,18 +335,11 @@ export const useChannelsData = () => {
setSearching(true);
try {
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
await loadChannels(page, pageSz, sortFlag, enableTagMode, typeKey, statusF);
return;
}
const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : '';
const statusParam = statusF !== 'all' ? `&status=${statusF}` : '';
const res = await API.get(
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${sortFlag}&tag_mode=${enableTagMode}&p=${page}&page_size=${pageSz}${typeParam}${statusParam}`,
@@ -373,10 +347,7 @@ export const useChannelsData = () => {
const { success, message, data } = res.data;
if (success) {
const { items = [], total = 0, type_counts = {} } = data;
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
setTypeCounts({ ...type_counts, all: sumAll });
setChannelFormat(items, enableTagMode);
setChannelCount(total);
@@ -395,14 +366,7 @@ export const useChannelsData = () => {
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
await loadChannels(page, pageSize, idSort, enableTagMode);
} else {
await searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort);
}
};
@@ -488,16 +452,9 @@ export const useChannelsData = () => {
const { searchKeyword, searchGroup, searchModel } = getFormValues();
setActivePage(page);
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
loadChannels(page, pageSize, idSort, enableTagMode).then(() => { });
} else {
searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort);
}
};
@@ -513,14 +470,7 @@ export const useChannelsData = () => {
showError(reason);
});
} else {
searchChannels(enableTagMode, activeTypeKey, statusFilter, 1, size, idSort);
}
};
@@ -551,10 +501,7 @@ export const useChannelsData = () => {
showError(res?.data?.message || t('渠道复制失败'));
}
} catch (error) {
showError(t('渠道复制失败: ') + (error?.response?.data?.message || error?.message || error));
}
};
@@ -593,11 +540,7 @@ export const useChannelsData = () => {
data.priority = parseInt(data.priority);
break;
case 'weight':
if (data.weight === undefined || data.weight < 0 || data.weight === '') {
showInfo('权重必须是非负整数!');
return;
}
@@ -740,136 +683,226 @@ export const useChannelsData = () => {
const res = await API.post(`/api/channel/fix`);
const { success, message, data } = res.data;
if (success) {
showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
await refresh();
} else {
showError(message);
}
};
// Test channel - single-model test, modeled on the old implementation
const testChannel = async (record, model) => {
const testKey = `${record.id}-${model}`;
// check whether batch testing should stop
if (shouldStopBatchTestingRef.current && isBatchTesting) {
return Promise.resolve();
}
// add to the set of models currently under test
setTestingModels(prev => new Set([...prev, model]));
try {
const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
// check whether testing was stopped while the request was in flight
if (shouldStopBatchTestingRef.current && isBatchTesting) {
return Promise.resolve();
}
const { success, message, time } = res.data;
// update the test results
setModelTestResults(prev => ({
...prev,
[testKey]: {
success,
message,
time: time || 0,
timestamp: Date.now()
}
}));
if (success) {
// update the channel's response time
updateChannelProperty(record.id, (channel) => {
channel.response_time = time * 1000;
channel.test_time = Date.now() / 1000;
});
if (!model || model === '') {
showInfo(
t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。')
.replace('${name}', record.name)
.replace('${time.toFixed(2)}', time.toFixed(2)),
);
} else {
showInfo(
t('通道 ${name} 测试成功,模型 ${model} 耗时 ${time.toFixed(2)} 秒。')
.replace('${name}', record.name)
.replace('${model}', model)
.replace('${time.toFixed(2)}', time.toFixed(2)),
);
}
} else {
showError(`${t('模型')} ${model}: ${message}`);
}
} catch (error) {
// handle network errors
const testKey = `${record.id}-${model}`;
setModelTestResults(prev => ({
...prev,
[testKey]: {
success: false,
message: error.message || t('网络错误'),
time: 0,
timestamp: Date.now()
}
}));
showError(`${t('模型')} ${model}: ${error.message || t('测试失败')}`);
} finally {
// remove from the set of models under test
setTestingModels(prev => {
const newSet = new Set(prev);
newSet.delete(model);
return newSet;
});
}
};
// batch-test all models of a single channel, modeled on the old implementation
const batchTestModels = async () => {
if (!currentTestChannel || !currentTestChannel.models) {
showError(t('渠道模型信息不完整'));
return;
}
const models = currentTestChannel.models.split(',').filter(model =>
model.toLowerCase().includes(modelSearchKeyword.toLowerCase())
);
if (models.length === 0) {
showError(t('没有找到匹配的模型'));
return;
}
setIsBatchTesting(true);
setModelTablePage(1);
shouldStopBatchTestingRef.current = false; // reset the stop flag
// clear this channel's previous test results
setModelTestResults(prev => {
const newResults = { ...prev };
models.forEach(model => {
const testKey = `${currentTestChannel.id}-${model}`;
delete newResults[testKey];
});
return newResults;
});
try {
showInfo(t('开始批量测试 ${count} 个模型,已清空上次结果...').replace('${count}', models.length));
// raise concurrency to speed up testing, following the old version's concurrency limit
const concurrencyLimit = 5;
const results = [];
for (let i = 0; i < models.length; i += concurrencyLimit) {
// check whether we should stop
if (shouldStopBatchTestingRef.current) {
showInfo(t('批量测试已停止'));
break;
}
const batch = models.slice(i, i + concurrencyLimit);
showInfo(t('正在测试第 ${current} - ${end} 个模型 (共 ${total} 个)')
.replace('${current}', i + 1)
.replace('${end}', Math.min(i + concurrencyLimit, models.length))
.replace('${total}', models.length)
);
const batchPromises = batch.map(model => testChannel(currentTestChannel, model));
const batchResults = await Promise.allSettled(batchPromises);
results.push(...batchResults);
// check again whether we should stop
if (shouldStopBatchTestingRef.current) {
showInfo(t('批量测试已停止'));
break;
}
// brief delay to avoid overly frequent requests
if (i + concurrencyLimit < models.length) {
await new Promise(resolve => setTimeout(resolve, 100));
}
}
if (!shouldStopBatchTestingRef.current) {
// wait briefly to make sure all results have been recorded
await new Promise(resolve => setTimeout(resolve, 300));
// recompute the result statistics from the current state
setModelTestResults(currentResults => {
let successCount = 0;
let failCount = 0;
models.forEach(model => {
const testKey = `${currentTestChannel.id}-${model}`;
const result = currentResults[testKey];
if (result && result.success) {
successCount++;
} else {
failCount++;
}
});
// show the completion message
setTimeout(() => {
showSuccess(t('批量测试完成!成功: ${success}, 失败: ${fail}, 总计: ${total}')
.replace('${success}', successCount)
.replace('${fail}', failCount)
.replace('${total}', models.length)
);
}, 100);
return currentResults; // no state change; only used to read the latest values
});
}
} catch (error) {
showError(t('批量测试过程中发生错误: ') + error.message);
} finally {
setIsBatchTesting(false);
}
};
// stop batch testing
const stopBatchTesting = () => {
shouldStopBatchTestingRef.current = true;
setIsBatchTesting(false);
setTestingModels(new Set());
showInfo(t('已停止批量测试'));
};
// clear test results
const clearTestResults = () => {
setModelTestResults({});
showInfo(t('已清空测试结果'));
};
// Handle close modal
const handleCloseModal = () => {
// if a batch test is running, stop it first
if (isBatchTesting) {
shouldStopBatchTestingRef.current = true;
showInfo(t('关闭弹窗,已停止批量测试'));
}
setShowModelTestModal(false);
setModelSearchKeyword('');
setIsBatchTesting(false);
setTestingModels(new Set());
setSelectedModelKeys([]);
setModelTablePage(1);
// optionally keep the test results; not cleared here so the user can review them
};
// Type counts
@@ -1012,4 +1045,4 @@ export const useChannelsData = () => {
setCompactMode,
setActivePage,
};
};