mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-07 21:37:27 +00:00
Compare commits
69 Commits
refactor/s
...
v0.9.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a91f3e7556 | ||
|
|
bf9a5f5b52 | ||
|
|
7d49ce6da7 | ||
|
|
a2b5efb6bd | ||
|
|
d916456801 | ||
|
|
9a1ef8b957 | ||
|
|
e2798fa62f | ||
|
|
b08f1889e8 | ||
|
|
045ba23566 | ||
|
|
7fe969c2ce | ||
|
|
b91eb8a5ac | ||
|
|
6e6a96d19f | ||
|
|
f6be18eca4 | ||
|
|
bdefed7b0a | ||
|
|
ee7ce5a476 | ||
|
|
6659a8a569 | ||
|
|
466d19c33d | ||
|
|
486c828df0 | ||
|
|
c68fd36ee1 | ||
|
|
74122e4175 | ||
|
|
2e4405e2bd | ||
|
|
923308a899 | ||
|
|
e4efa34e6a | ||
|
|
143a2def24 | ||
|
|
ffc077490c | ||
|
|
476cf10495 | ||
|
|
b294ff5e96 | ||
|
|
096141bfef | ||
|
|
9e8b9995a6 | ||
|
|
a498da7ab2 | ||
|
|
ad72500941 | ||
|
|
79859a3fc6 | ||
|
|
5197d874d7 | ||
|
|
e9e9708d1e | ||
|
|
e0c6900195 | ||
|
|
bf99ead4a4 | ||
|
|
474db61e56 | ||
|
|
406be515db | ||
|
|
7794788b1e | ||
|
|
2f74cc077b | ||
|
|
25a8473e85 | ||
|
|
c25f487c8f | ||
|
|
4f05c8eafb | ||
|
|
f4d95bf1c4 | ||
|
|
391d4514c0 | ||
|
|
c89c8a7396 | ||
|
|
d2defa1253 | ||
|
|
127029d62d | ||
|
|
6c5181977d | ||
|
|
b69245212a | ||
|
|
2a54e989b4 | ||
|
|
2ffdf738bd | ||
|
|
b4a6721948 | ||
|
|
6c0b1681f9 | ||
|
|
4b98773e9a | ||
|
|
f19b5b8680 | ||
|
|
69a88a0563 | ||
|
|
1dd78b83b7 | ||
|
|
62549717e0 | ||
|
|
4eeca081fe | ||
|
|
9d952e0d78 | ||
|
|
f7d393fc72 | ||
|
|
176fd6eda1 | ||
|
|
7d6ba52d85 | ||
|
|
fc38c480a1 | ||
|
|
51c4cd9ab5 | ||
|
|
99a8b5eef0 | ||
|
|
ef0780c096 | ||
|
|
da98972dda |
8
.github/workflows/linux-release.yml
vendored
8
.github/workflows/linux-release.yml
vendored
@@ -38,21 +38,21 @@ jobs:
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api
|
||||
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api-arm64
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
one-api
|
||||
one-api-arm64
|
||||
new-api
|
||||
new-api-arm64
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
|
||||
4
.github/workflows/macos-release.yml
vendored
4
.github/workflows/macos-release.yml
vendored
@@ -39,12 +39,12 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
|
||||
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o new-api-macos
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: one-api-macos
|
||||
files: new-api-macos
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
|
||||
4
.github/workflows/windows-release.yml
vendored
4
.github/workflows/windows-release.yml
vendored
@@ -41,12 +41,12 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o new-api.exe
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: one-api.exe
|
||||
files: new-api.exe
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
|
||||
@@ -2,9 +2,10 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SysLog(s string) {
|
||||
@@ -22,3 +23,33 @@ func FatalLog(v ...any) {
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func LogStartupSuccess(startTime time.Time, port string) {
|
||||
|
||||
duration := time.Since(startTime)
|
||||
durationMs := duration.Milliseconds()
|
||||
|
||||
// Get network IPs
|
||||
networkIps := GetNetworkIps()
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
|
||||
// Print the main success message
|
||||
fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
|
||||
// Skip fancy startup message in container environments
|
||||
if !IsRunningInContainer() {
|
||||
// Print local URL
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
|
||||
}
|
||||
|
||||
// Print network URLs
|
||||
for _, ip := range networkIps {
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
|
||||
}
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
}
|
||||
|
||||
@@ -68,6 +68,78 @@ func GetIp() (ip string) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetNetworkIps() []string {
|
||||
var networkIps []string
|
||||
ips, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return networkIps
|
||||
}
|
||||
|
||||
for _, a := range ips {
|
||||
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||
if ipNet.IP.To4() != nil {
|
||||
ip := ipNet.IP.String()
|
||||
// Include common private network ranges
|
||||
if strings.HasPrefix(ip, "10.") ||
|
||||
strings.HasPrefix(ip, "172.") ||
|
||||
strings.HasPrefix(ip, "192.168.") {
|
||||
networkIps = append(networkIps, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return networkIps
|
||||
}
|
||||
|
||||
// IsRunningInContainer detects if the application is running inside a container
|
||||
func IsRunningInContainer() bool {
|
||||
// Method 1: Check for .dockerenv file (Docker containers)
|
||||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Method 2: Check cgroup for container indicators
|
||||
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
|
||||
content := string(data)
|
||||
if strings.Contains(content, "docker") ||
|
||||
strings.Contains(content, "containerd") ||
|
||||
strings.Contains(content, "kubepods") ||
|
||||
strings.Contains(content, "/lxc/") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Method 3: Check environment variables commonly set by container runtimes
|
||||
containerEnvVars := []string{
|
||||
"KUBERNETES_SERVICE_HOST",
|
||||
"DOCKER_CONTAINER",
|
||||
"container",
|
||||
}
|
||||
|
||||
for _, envVar := range containerEnvVars {
|
||||
if os.Getenv(envVar) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Method 4: Check if init process is not the traditional init
|
||||
if data, err := os.ReadFile("/proc/1/comm"); err == nil {
|
||||
comm := strings.TrimSpace(string(data))
|
||||
// In containers, process 1 is often not "init" or "systemd"
|
||||
if comm != "init" && comm != "systemd" {
|
||||
// Additional check: if it's a common container entrypoint
|
||||
if strings.Contains(comm, "docker") ||
|
||||
strings.Contains(comm, "containerd") ||
|
||||
strings.Contains(comm, "runc") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var sizeKB = 1024
|
||||
var sizeMB = sizeKB * 1024
|
||||
var sizeGB = sizeMB * 1024
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -188,6 +189,8 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
@@ -631,6 +634,7 @@ func AddChannel(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
service.ResetProxyClientCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -892,6 +896,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
service.ResetProxyClientCache()
|
||||
channel.Key = ""
|
||||
clearChannelInfo(&channel.Channel)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1101,8 +1106,8 @@ func CopyChannel(c *gin.Context) {
|
||||
// MultiKeyManageRequest represents the request for multi-key management operations
|
||||
type MultiKeyManageRequest struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
|
||||
Page int `json:"page,omitempty"` // for get_key_status pagination
|
||||
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
|
||||
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
|
||||
@@ -1430,6 +1435,86 @@ func ManageMultiKeys(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要删除的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
var newStatusList = make(map[int]int)
|
||||
var newDisabledTime = make(map[int]int64)
|
||||
var newDisabledReason = make(map[int]string)
|
||||
|
||||
newIndex := 0
|
||||
for i, key := range keys {
|
||||
// 跳过要删除的密钥
|
||||
if i == keyIndex {
|
||||
continue
|
||||
}
|
||||
|
||||
remainingKeys = append(remainingKeys, key)
|
||||
|
||||
// 保留其他密钥的状态信息,重新索引
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
|
||||
newStatusList[newIndex] = status
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||||
newDisabledTime[newIndex] = t
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||||
newDisabledReason[newIndex] = r
|
||||
}
|
||||
}
|
||||
newIndex++
|
||||
}
|
||||
|
||||
if len(remainingKeys) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不能删除最后一个密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update channel with remaining keys
|
||||
channel.Key = strings.Join(remainingKeys, "\n")
|
||||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已删除",
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_disabled_keys":
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
|
||||
@@ -225,7 +225,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
|
||||
@@ -19,4 +19,12 @@ const (
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
|
||||
if s == nil || s.OpenRouterEnterprise == nil {
|
||||
return false
|
||||
}
|
||||
return *s.OpenRouterEnterprise
|
||||
}
|
||||
|
||||
@@ -14,7 +14,30 @@ type GeminiChatRequest struct {
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||
CachedContent string `json:"cachedContent,omitempty"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
|
||||
RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCallingConfig struct {
|
||||
Mode FunctionCallingConfigMode `json:"mode,omitempty"`
|
||||
AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
|
||||
}
|
||||
type FunctionCallingConfigMode string
|
||||
|
||||
type RetrievalConfig struct {
|
||||
LatLng *LatLng `json:"latLng,omitempty"`
|
||||
LanguageCode string `json:"languageCode,omitempty"`
|
||||
}
|
||||
|
||||
type LatLng struct {
|
||||
Latitude *float64 `json:"latitude,omitempty"`
|
||||
Longitude *float64 `json:"longitude,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
@@ -228,6 +251,7 @@ type GeminiChatTool struct {
|
||||
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
|
||||
CodeExecution any `json:"codeExecution,omitempty"`
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
URLContext any `json:"urlContext,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
@@ -239,12 +263,20 @@ type GeminiChatGenerationConfig struct {
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
|
||||
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
|
||||
Logprobs *int32 `json:"logprobs,omitempty"`
|
||||
MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type MediaResolution string
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason *string `json:"finishReason"`
|
||||
|
||||
26
main.go
26
main.go
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -16,6 +17,8 @@ import (
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-contrib/sessions"
|
||||
@@ -33,6 +36,7 @@ var buildFS embed.FS
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
startTime := time.Now()
|
||||
|
||||
err := InitResources()
|
||||
if err != nil {
|
||||
@@ -145,11 +149,31 @@ func main() {
|
||||
})
|
||||
server.Use(sessions.Sessions("session", store))
|
||||
|
||||
analyticsInjectBuilder := &strings.Builder{}
|
||||
if os.Getenv("UMAMI_WEBSITE_ID") != "" {
|
||||
umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
|
||||
umamiScriptURL := os.Getenv("UMAMI_SCRIPT_URL")
|
||||
if umamiScriptURL == "" {
|
||||
umamiScriptURL = "https://analytics.umami.is/script.js"
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<script defer src=\"")
|
||||
analyticsInjectBuilder.WriteString(umamiScriptURL)
|
||||
analyticsInjectBuilder.WriteString("\" data-website-id=\"")
|
||||
analyticsInjectBuilder.WriteString(umamiSiteID)
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<analytics></analytics>\n"), []byte(analyticsInject))
|
||||
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
}
|
||||
|
||||
// Log startup success message
|
||||
common.LogStartupSuccess(startTime, port)
|
||||
|
||||
err = server.Run(":" + port)
|
||||
if err != nil {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
@@ -204,4 +228,4 @@ func InitResources() error {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||
common.OptionMap["StripePriceId"] = setting.StripePriceId
|
||||
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -330,6 +331,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "StripeMinTopUp":
|
||||
setting.StripeMinTopUp, _ = strconv.Atoi(value)
|
||||
case "StripePromotionCodesEnabled":
|
||||
setting.StripePromotionCodesEnabled = value == "true"
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
|
||||
@@ -265,6 +265,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(c, "do request failed: "+err.Error())
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
|
||||
@@ -21,6 +21,10 @@ var awsModelIDMap = map[string]string{
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
"nova-canvas-v1:0": "amazon.nova-canvas-v1:0",
|
||||
"nova-reel-v1:0": "amazon.nova-reel-v1:0",
|
||||
"nova-reel-v1:1": "amazon.nova-reel-v1:1",
|
||||
"nova-sonic-v1:0": "amazon.nova-sonic-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -82,10 +86,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"amazon.nova-canvas-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
}}
|
||||
},
|
||||
"amazon.nova-reel-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-reel-v1:1": {
|
||||
"us": true,
|
||||
},
|
||||
"amazon.nova-sonic-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
|
||||
@@ -245,6 +245,7 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
|
||||
googleSearch := false
|
||||
codeExecution := false
|
||||
urlContext := false
|
||||
for _, tool := range textRequest.Tools {
|
||||
if tool.Function.Name == "googleSearch" {
|
||||
googleSearch = true
|
||||
@@ -254,6 +255,10 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
codeExecution = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Name == "urlContext" {
|
||||
urlContext = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
@@ -281,6 +286,11 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if urlContext {
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
URLContext: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -17,10 +18,7 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
@@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
|
||||
return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
return info.ChannelBaseUrl + "/v1/chat/completions", nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.ChannelBaseUrl + "/api/embed", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
|
||||
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
|
||||
return info.ChannelBaseUrl + "/api/chat", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
if request == nil { return nil, errors.New("request is nil") }
|
||||
// decide generate or chat
|
||||
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
|
||||
return openAIToGenerate(c, request)
|
||||
}
|
||||
return requestOpenAI2Ollama(c, request)
|
||||
return openAIChatToOllamaChat(c, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return requestOpenAI2Embeddings(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
@@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
return ollamaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
return ollamaStreamHandler(c, info, resp)
|
||||
}
|
||||
return ollamaChatHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -2,48 +2,69 @@ package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
type OllamaChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
type OllamaToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaTool struct {
|
||||
Type string `json:"type"`
|
||||
Function OllamaToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
type OllamaToolCall struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type OllamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaChatMessage `json:"messages"`
|
||||
Tools interface{} `json:"tools,omitempty"`
|
||||
Format interface{} `json:"format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
KeepAlive interface{} `json:"keep_alive,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaGenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Format interface{} `json:"format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
KeepAlive interface{} `json:"keep_alive,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Input []string `json:"input"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input interface{} `json:"input"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embedding [][]float64 `json:"embeddings,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float64 `json:"embeddings"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -14,121 +15,176 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if not base64
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
|
||||
chatReq := &OllamaChatRequest{
|
||||
Model: r.Model,
|
||||
Stream: r.Stream,
|
||||
Options: map[string]any{},
|
||||
Think: r.Think,
|
||||
}
|
||||
if r.ResponseFormat != nil {
|
||||
if r.ResponseFormat.Type == "json" {
|
||||
chatReq.Format = "json"
|
||||
} else if r.ResponseFormat.Type == "json_schema" {
|
||||
if len(r.ResponseFormat.JsonSchema) > 0 {
|
||||
var schema any
|
||||
_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
|
||||
chatReq.Format = schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// options mapping
|
||||
if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
|
||||
if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
|
||||
if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
|
||||
if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
|
||||
|
||||
if r.Stop != nil {
|
||||
switch v := r.Stop.(type) {
|
||||
case string:
|
||||
chatReq.Options["stop"] = []string{v}
|
||||
case []string:
|
||||
chatReq.Options["stop"] = v
|
||||
case []any:
|
||||
arr := make([]string,0,len(v))
|
||||
for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
|
||||
if len(arr)>0 { chatReq.Options["stop"] = arr }
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.Tools) > 0 {
|
||||
tools := make([]OllamaTool,0,len(r.Tools))
|
||||
for _, t := range r.Tools {
|
||||
tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
|
||||
}
|
||||
chatReq.Tools = tools
|
||||
}
|
||||
|
||||
chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
|
||||
for _, m := range r.Messages {
|
||||
var textBuilder strings.Builder
|
||||
var images []string
|
||||
if m.IsStringContent() {
|
||||
textBuilder.WriteString(m.StringContent())
|
||||
} else {
|
||||
parts := m.ParseContent()
|
||||
for _, part := range parts {
|
||||
if part.Type == dto.ContentTypeImageURL {
|
||||
img := part.GetImageMedia()
|
||||
if img != nil && img.Url != "" {
|
||||
var base64Data string
|
||||
if strings.HasPrefix(img.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
|
||||
if err != nil { return nil, err }
|
||||
base64Data = fileData.Base64Data
|
||||
} else if strings.HasPrefix(img.Url, "data:") {
|
||||
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
|
||||
} else {
|
||||
base64Data = img.Url
|
||||
}
|
||||
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
|
||||
if base64Data != "" { images = append(images, base64Data) }
|
||||
}
|
||||
mediaMessage.ImageUrl = imageUrl
|
||||
mediaMessages[j] = mediaMessage
|
||||
} else if part.Type == dto.ContentTypeText {
|
||||
textBuilder.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
message.SetMediaContent(mediaMessages)
|
||||
}
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
})
|
||||
cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
|
||||
if len(images)>0 { cm.Images = images }
|
||||
if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
|
||||
if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
|
||||
parsed := m.ParseToolCalls()
|
||||
if len(parsed) > 0 {
|
||||
calls := make([]OllamaToolCall,0,len(parsed))
|
||||
for _, tc := range parsed {
|
||||
var args interface{}
|
||||
if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
|
||||
if args==nil { args = map[string]any{} }
|
||||
oc := OllamaToolCall{}
|
||||
oc.Function.Name = tc.Function.Name
|
||||
oc.Function.Arguments = args
|
||||
calls = append(calls, oc)
|
||||
}
|
||||
cm.ToolCalls = calls
|
||||
}
|
||||
}
|
||||
chatReq.Messages = append(chatReq.Messages, cm)
|
||||
}
|
||||
str, ok := request.Stop.(string)
|
||||
var Stop []string
|
||||
if ok {
|
||||
Stop = []string{str}
|
||||
} else {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
ollamaRequest := &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
Prompt: request.Prompt,
|
||||
StreamOptions: request.StreamOptions,
|
||||
Suffix: request.Suffix,
|
||||
}
|
||||
ollamaRequest.Think = request.Think
|
||||
return ollamaRequest, nil
|
||||
return chatReq, nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Input: request.ParseInput(),
|
||||
Options: &Options{
|
||||
Seed: int(request.Seed),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
},
|
||||
// openAIToGenerate converts OpenAI completions request to Ollama generate
|
||||
func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
|
||||
gen := &OllamaGenerateRequest{
|
||||
Model: r.Model,
|
||||
Stream: r.Stream,
|
||||
Options: map[string]any{},
|
||||
Think: r.Think,
|
||||
}
|
||||
// Prompt may be in r.Prompt (string or []any)
|
||||
if r.Prompt != nil {
|
||||
switch v := r.Prompt.(type) {
|
||||
case string:
|
||||
gen.Prompt = v
|
||||
case []any:
|
||||
var sb strings.Builder
|
||||
for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
|
||||
gen.Prompt = sb.String()
|
||||
default:
|
||||
gen.Prompt = fmt.Sprintf("%v", r.Prompt)
|
||||
}
|
||||
}
|
||||
if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
|
||||
if r.ResponseFormat != nil {
|
||||
if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
|
||||
}
|
||||
if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
|
||||
if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
|
||||
if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
|
||||
if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
|
||||
if r.Stop != nil {
|
||||
switch v := r.Stop.(type) {
|
||||
case string: gen.Options["stop"] = []string{v}
|
||||
case []string: gen.Options["stop"] = v
|
||||
case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
|
||||
}
|
||||
}
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
opts := map[string]any{}
|
||||
if r.Temperature != nil { opts["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { opts["top_p"] = r.TopP }
|
||||
if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { opts["seed"] = int(r.Seed) }
|
||||
if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
|
||||
input := r.ParseInput()
|
||||
if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
|
||||
return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
|
||||
}
|
||||
|
||||
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
var oResp OllamaEmbeddingResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
data = append(data, dto.OpenAIEmbeddingResponseItem{
|
||||
Embedding: flattenedEmbeddings,
|
||||
Object: "embedding",
|
||||
})
|
||||
usage := &dto.Usage{
|
||||
TotalTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
PromptTokens: info.PromptTokens,
|
||||
}
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: info.UpstreamModelName,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
|
||||
for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
|
||||
usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
|
||||
embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
|
||||
out, _ := common.Marshal(embResp)
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
||||
flattened := []float64{}
|
||||
for _, row := range embeddings {
|
||||
flattened = append(flattened, row...)
|
||||
}
|
||||
return flattened
|
||||
}
|
||||
|
||||
210
relay/channel/ollama/stream.go
Normal file
210
relay/channel/ollama/stream.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ollamaChatStreamChunk struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
// chat
|
||||
Message *struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Thinking json.RawMessage `json:"thinking"`
|
||||
ToolCalls []struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
// generate
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
LoadDuration int64 `json:"load_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration"`
|
||||
EvalDuration int64 `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func toUnix(ts string) int64 {
|
||||
if ts == "" { return time.Now().Unix() }
|
||||
// try time.RFC3339 or with nanoseconds
|
||||
t, err := time.Parse(time.RFC3339Nano, ts)
|
||||
if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
usage := &dto.Usage{}
|
||||
var model = info.UpstreamModelName
|
||||
var responseId = common.GetUUID()
|
||||
var created = time.Now().Unix()
|
||||
var toolCallIndex int
|
||||
start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
|
||||
if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" { continue }
|
||||
var chunk ollamaChatStreamChunk
|
||||
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
||||
logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
|
||||
return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if chunk.Model != "" { model = chunk.Model }
|
||||
created = toUnix(chunk.CreatedAt)
|
||||
|
||||
if !chunk.Done {
|
||||
// delta content
|
||||
var content string
|
||||
if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
|
||||
delta := dto.ChatCompletionsStreamResponse{
|
||||
Id: responseId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: created,
|
||||
Model: model,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{ {
|
||||
Index: 0,
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
|
||||
} },
|
||||
}
|
||||
if content != "" { delta.Choices[0].Delta.SetContentString(content) }
|
||||
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(chunk.Message.Thinking))
|
||||
if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
|
||||
}
|
||||
// tool calls
|
||||
if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
|
||||
delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
|
||||
for _, tc := range chunk.Message.ToolCalls {
|
||||
// arguments -> string
|
||||
argBytes, _ := json.Marshal(tc.Function.Arguments)
|
||||
toolId := fmt.Sprintf("call_%d", toolCallIndex)
|
||||
tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
|
||||
tr.SetIndex(toolCallIndex)
|
||||
toolCallIndex++
|
||||
delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
|
||||
}
|
||||
}
|
||||
if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
continue
|
||||
}
|
||||
// done frame
|
||||
// finalize once and break loop
|
||||
usage.PromptTokens = chunk.PromptEvalCount
|
||||
usage.CompletionTokens = chunk.EvalCount
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
finishReason := chunk.DoneReason
|
||||
if finishReason == "" { finishReason = "stop" }
|
||||
// emit stop delta
|
||||
if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
|
||||
if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
}
|
||||
// emit usage frame
|
||||
if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
|
||||
if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
}
|
||||
// send [DONE]
|
||||
helper.Done(c)
|
||||
break
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// non-stream handler for chat/generate
|
||||
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
raw := string(body)
|
||||
if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
|
||||
|
||||
lines := strings.Split(raw, "\n")
|
||||
var (
|
||||
aggContent strings.Builder
|
||||
reasoningBuilder strings.Builder
|
||||
lastChunk ollamaChatStreamChunk
|
||||
parsedAny bool
|
||||
)
|
||||
for _, ln := range lines {
|
||||
ln = strings.TrimSpace(ln)
|
||||
if ln == "" { continue }
|
||||
var ck ollamaChatStreamChunk
|
||||
if err := json.Unmarshal([]byte(ln), &ck); err != nil {
|
||||
if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
continue
|
||||
}
|
||||
parsedAny = true
|
||||
lastChunk = ck
|
||||
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(ck.Message.Thinking))
|
||||
if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
|
||||
}
|
||||
if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
|
||||
}
|
||||
|
||||
if !parsedAny {
|
||||
var single ollamaChatStreamChunk
|
||||
if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
lastChunk = single
|
||||
if single.Message != nil {
|
||||
if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
|
||||
aggContent.WriteString(single.Message.Content)
|
||||
} else { aggContent.WriteString(single.Response) }
|
||||
}
|
||||
|
||||
model := lastChunk.Model
|
||||
if model == "" { model = info.UpstreamModelName }
|
||||
created := toUnix(lastChunk.CreatedAt)
|
||||
usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
|
||||
content := aggContent.String()
|
||||
finishReason := lastChunk.DoneReason
|
||||
if finishReason == "" { finishReason = "stop" }
|
||||
|
||||
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
|
||||
if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
|
||||
full := dto.OpenAITextResponse{
|
||||
Id: common.GetUUID(),
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Created: created,
|
||||
Choices: []dto.OpenAITextResponseChoice{ {
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
} },
|
||||
Usage: *usage,
|
||||
}
|
||||
out, _ := common.Marshal(full)
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func contentPtr(s string) *string { if s=="" { return nil }; return &s }
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
if common.DebugEnabled {
|
||||
println("upstream response body:", string(responseBody))
|
||||
}
|
||||
// Unmarshal to simpleResponse
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
|
||||
// 尝试解析为 openrouter enterprise
|
||||
var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
|
||||
err = common.Unmarshal(responseBody, &enterpriseResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if enterpriseResponse.Success {
|
||||
responseBody = enterpriseResponse.Data
|
||||
} else {
|
||||
logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package openrouter
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type RequestReasoning struct {
|
||||
// One of the following (not both):
|
||||
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
|
||||
@@ -7,3 +9,8 @@ type RequestReasoning struct {
|
||||
// Optional: Default is false. All models support this.
|
||||
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
|
||||
}
|
||||
|
||||
type OpenRouterEnterpriseResponse struct {
|
||||
Data json.RawMessage `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
channelconstant "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
@@ -188,20 +189,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// 支持自定义域名,如果未设置则使用默认域名
|
||||
baseUrl := info.ChannelBaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
|
||||
@@ -9,6 +9,11 @@ var ModelList = []string{
|
||||
"Doubao-lite-4k",
|
||||
"Doubao-embedding",
|
||||
"doubao-seedream-4-0-250828",
|
||||
"seedream-4-0-250828",
|
||||
"doubao-seedance-1-0-pro-250528",
|
||||
"seedance-1-0-pro-250528",
|
||||
"doubao-seed-1-6-thinking-250715",
|
||||
"seed-1-6-thinking-250715",
|
||||
}
|
||||
|
||||
var ChannelName = "volcengine"
|
||||
|
||||
@@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
@@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
|
||||
@@ -7,12 +7,17 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
var (
|
||||
httpClient *http.Client
|
||||
proxyClientLock sync.Mutex
|
||||
proxyClients = make(map[string]*http.Client)
|
||||
)
|
||||
|
||||
func InitHttpClient() {
|
||||
if common.RelayTimeout == 0 {
|
||||
@@ -28,12 +33,31 @@ func GetHttpClient() *http.Client {
|
||||
return httpClient
|
||||
}
|
||||
|
||||
// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
|
||||
func ResetProxyClientCache() {
|
||||
proxyClientLock.Lock()
|
||||
defer proxyClientLock.Unlock()
|
||||
for _, client := range proxyClients {
|
||||
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
proxyClients = make(map[string]*http.Client)
|
||||
}
|
||||
|
||||
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
|
||||
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
if proxyURL == "" {
|
||||
return http.DefaultClient, nil
|
||||
}
|
||||
|
||||
proxyClientLock.Lock()
|
||||
if client, ok := proxyClients[proxyURL]; ok {
|
||||
proxyClientLock.Unlock()
|
||||
return client, nil
|
||||
}
|
||||
proxyClientLock.Unlock()
|
||||
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -41,11 +65,16 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
return &http.Client{
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||
proxyClientLock.Lock()
|
||||
proxyClients[proxyURL] = client
|
||||
proxyClientLock.Unlock()
|
||||
return client, nil
|
||||
|
||||
case "socks5", "socks5h":
|
||||
// 获取认证信息
|
||||
@@ -67,13 +96,18 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||
proxyClientLock.Lock()
|
||||
proxyClients[proxyURL] = client
|
||||
proxyClientLock.Unlock()
|
||||
return client, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
|
||||
@@ -5,3 +5,4 @@ var StripeWebhookSecret = ""
|
||||
var StripePriceId = ""
|
||||
var StripeUnitPrice = 8.0
|
||||
var StripeMinTopUp = 1
|
||||
var StripePromotionCodesEnabled = false
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
|
||||
/>
|
||||
<title>New API</title>
|
||||
<analytics></analytics>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
@@ -181,8 +181,8 @@ export function PreCode(props) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (ref.current) {
|
||||
const code =
|
||||
ref.current.querySelector('code')?.innerText ?? '';
|
||||
const codeElement = ref.current.querySelector('code');
|
||||
const code = codeElement?.textContent ?? '';
|
||||
copy(code).then((success) => {
|
||||
if (success) {
|
||||
Toast.success(t('代码已复制到剪贴板'));
|
||||
|
||||
@@ -45,6 +45,7 @@ const PaymentSetting = () => {
|
||||
StripePriceId: '',
|
||||
StripeUnitPrice: 8.0,
|
||||
StripeMinTopUp: 1,
|
||||
StripePromotionCodesEnabled: false,
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
@@ -19,7 +19,14 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { API, copy, showError, showInfo, showSuccess } from '../../helpers';
|
||||
import {
|
||||
API,
|
||||
copy,
|
||||
showError,
|
||||
showInfo,
|
||||
showSuccess,
|
||||
setStatusData,
|
||||
} from '../../helpers';
|
||||
import { UserContext } from '../../context/User';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -71,18 +78,40 @@ const PersonalSetting = () => {
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
let status = localStorage.getItem('status');
|
||||
if (status) {
|
||||
status = JSON.parse(status);
|
||||
setStatus(status);
|
||||
if (status.turnstile_check) {
|
||||
let saved = localStorage.getItem('status');
|
||||
if (saved) {
|
||||
const parsed = JSON.parse(saved);
|
||||
setStatus(parsed);
|
||||
if (parsed.turnstile_check) {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(status.turnstile_site_key);
|
||||
setTurnstileSiteKey(parsed.turnstile_site_key);
|
||||
} else {
|
||||
setTurnstileEnabled(false);
|
||||
setTurnstileSiteKey('');
|
||||
}
|
||||
}
|
||||
getUserData().then((res) => {
|
||||
console.log(userState);
|
||||
});
|
||||
// Always refresh status from server to avoid stale flags (e.g., admin just enabled OAuth)
|
||||
(async () => {
|
||||
try {
|
||||
const res = await API.get('/api/status');
|
||||
const { success, data } = res.data;
|
||||
if (success && data) {
|
||||
setStatus(data);
|
||||
setStatusData(data);
|
||||
if (data.turnstile_check) {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(data.turnstile_site_key);
|
||||
} else {
|
||||
setTurnstileEnabled(false);
|
||||
setTurnstileSiteKey('');
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// ignore and keep local status
|
||||
}
|
||||
})();
|
||||
|
||||
getUserData();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
Tabs,
|
||||
TabPane,
|
||||
Popover,
|
||||
Modal,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconMail,
|
||||
@@ -83,6 +84,9 @@ const AccountManagement = ({
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
const isBound = (accountId) => Boolean(accountId);
|
||||
const [showTelegramBindModal, setShowTelegramBindModal] = React.useState(false);
|
||||
|
||||
return (
|
||||
<Card className='!rounded-2xl'>
|
||||
{/* 卡片头部 */}
|
||||
@@ -142,7 +146,7 @@ const AccountManagement = ({
|
||||
size='small'
|
||||
onClick={() => setShowEmailBindModal(true)}
|
||||
>
|
||||
{userState.user && userState.user.email !== ''
|
||||
{isBound(userState.user?.email)
|
||||
? t('修改绑定')
|
||||
: t('绑定')}
|
||||
</Button>
|
||||
@@ -165,9 +169,11 @@ const AccountManagement = ({
|
||||
{t('微信')}
|
||||
</div>
|
||||
<div className='text-sm text-gray-500 truncate'>
|
||||
{userState.user && userState.user.wechat_id !== ''
|
||||
? t('已绑定')
|
||||
: t('未绑定')}
|
||||
{!status.wechat_login
|
||||
? t('未启用')
|
||||
: isBound(userState.user?.wechat_id)
|
||||
? t('已绑定')
|
||||
: t('未绑定')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -179,7 +185,7 @@ const AccountManagement = ({
|
||||
disabled={!status.wechat_login}
|
||||
onClick={() => setShowWeChatBindModal(true)}
|
||||
>
|
||||
{userState.user && userState.user.wechat_id !== ''
|
||||
{isBound(userState.user?.wechat_id)
|
||||
? t('修改绑定')
|
||||
: status.wechat_login
|
||||
? t('绑定')
|
||||
@@ -220,8 +226,7 @@ const AccountManagement = ({
|
||||
onGitHubOAuthClicked(status.github_client_id)
|
||||
}
|
||||
disabled={
|
||||
(userState.user && userState.user.github_id !== '') ||
|
||||
!status.github_oauth
|
||||
isBound(userState.user?.github_id) || !status.github_oauth
|
||||
}
|
||||
>
|
||||
{status.github_oauth ? t('绑定') : t('未启用')}
|
||||
@@ -264,8 +269,7 @@ const AccountManagement = ({
|
||||
)
|
||||
}
|
||||
disabled={
|
||||
(userState.user && userState.user.oidc_id !== '') ||
|
||||
!status.oidc_enabled
|
||||
isBound(userState.user?.oidc_id) || !status.oidc_enabled
|
||||
}
|
||||
>
|
||||
{status.oidc_enabled ? t('绑定') : t('未启用')}
|
||||
@@ -298,26 +302,56 @@ const AccountManagement = ({
|
||||
</div>
|
||||
<div className='flex-shrink-0'>
|
||||
{status.telegram_oauth ? (
|
||||
userState.user.telegram_id !== '' ? (
|
||||
<Button disabled={true} size='small'>
|
||||
isBound(userState.user?.telegram_id) ? (
|
||||
<Button
|
||||
disabled
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
>
|
||||
{t('已绑定')}
|
||||
</Button>
|
||||
) : (
|
||||
<div className='scale-75'>
|
||||
<TelegramLoginButton
|
||||
dataAuthUrl='/api/oauth/telegram/bind'
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='outline'
|
||||
size='small'
|
||||
onClick={() => setShowTelegramBindModal(true)}
|
||||
>
|
||||
{t('绑定')}
|
||||
</Button>
|
||||
)
|
||||
) : (
|
||||
<Button disabled={true} size='small'>
|
||||
<Button
|
||||
disabled
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
>
|
||||
{t('未启用')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
<Modal
|
||||
title={t('绑定 Telegram')}
|
||||
visible={showTelegramBindModal}
|
||||
onCancel={() => setShowTelegramBindModal(false)}
|
||||
footer={null}
|
||||
>
|
||||
<div className='my-3 text-sm text-gray-600'>
|
||||
{t('点击下方按钮通过 Telegram 完成绑定')}
|
||||
</div>
|
||||
<div className='flex justify-center'>
|
||||
<div className='scale-90'>
|
||||
<TelegramLoginButton
|
||||
dataAuthUrl='/api/oauth/telegram/bind'
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
|
||||
{/* LinuxDO绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
@@ -350,8 +384,7 @@ const AccountManagement = ({
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id)
|
||||
}
|
||||
disabled={
|
||||
(userState.user && userState.user.linux_do_id !== '') ||
|
||||
!status.linuxdo_oauth
|
||||
isBound(userState.user?.linux_do_id) || !status.linuxdo_oauth
|
||||
}
|
||||
>
|
||||
{status.linuxdo_oauth ? t('绑定') : t('未启用')}
|
||||
|
||||
@@ -85,6 +85,26 @@ const REGION_EXAMPLE = {
|
||||
'claude-3-5-sonnet-20240620': 'europe-west1',
|
||||
};
|
||||
|
||||
// 支持并且已适配通过接口获取模型列表的渠道类型
|
||||
const MODEL_FETCHABLE_TYPES = new Set([
|
||||
1,
|
||||
4,
|
||||
14,
|
||||
34,
|
||||
17,
|
||||
26,
|
||||
24,
|
||||
47,
|
||||
25,
|
||||
20,
|
||||
23,
|
||||
31,
|
||||
35,
|
||||
40,
|
||||
42,
|
||||
48,
|
||||
]);
|
||||
|
||||
function type2secretPrompt(type) {
|
||||
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
||||
switch (type) {
|
||||
@@ -144,6 +164,8 @@ const EditChannelModal = (props) => {
|
||||
settings: '',
|
||||
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
|
||||
vertex_key_type: 'json',
|
||||
// 企业账户设置
|
||||
is_enterprise_account: false,
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
const [multiToSingle, setMultiToSingle] = useState(false);
|
||||
@@ -169,6 +191,7 @@ const EditChannelModal = (props) => {
|
||||
const [channelSearchValue, setChannelSearchValue] = useState('');
|
||||
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
|
||||
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
|
||||
|
||||
// 2FA验证查看密钥相关状态
|
||||
const [twoFAState, setTwoFAState] = useState({
|
||||
@@ -215,7 +238,7 @@ const EditChannelModal = (props) => {
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
});
|
||||
const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示)
|
||||
const showApiConfigCard = true; // 控制是否显示 API 配置卡片
|
||||
const getInitValues = () => ({ ...originInputs });
|
||||
|
||||
// 处理渠道额外设置的更新
|
||||
@@ -322,6 +345,10 @@ const EditChannelModal = (props) => {
|
||||
case 36:
|
||||
localModels = ['suno_music', 'suno_lyrics'];
|
||||
break;
|
||||
case 45:
|
||||
localModels = getChannelModels(value);
|
||||
setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' }));
|
||||
break;
|
||||
default:
|
||||
localModels = getChannelModels(value);
|
||||
break;
|
||||
@@ -413,15 +440,27 @@ const EditChannelModal = (props) => {
|
||||
parsedSettings.azure_responses_version || '';
|
||||
// 读取 Vertex 密钥格式
|
||||
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
|
||||
// 读取企业账户设置
|
||||
data.is_enterprise_account = parsedSettings.openrouter_enterprise === true;
|
||||
} catch (error) {
|
||||
console.error('解析其他设置失败:', error);
|
||||
data.azure_responses_version = '';
|
||||
data.region = '';
|
||||
data.vertex_key_type = 'json';
|
||||
data.is_enterprise_account = false;
|
||||
}
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
data.vertex_key_type = 'json';
|
||||
data.is_enterprise_account = false;
|
||||
}
|
||||
|
||||
if (
|
||||
data.type === 45 &&
|
||||
(!data.base_url ||
|
||||
(typeof data.base_url === 'string' && data.base_url.trim() === ''))
|
||||
) {
|
||||
data.base_url = 'https://ark.cn-beijing.volces.com';
|
||||
}
|
||||
|
||||
setInputs(data);
|
||||
@@ -433,6 +472,8 @@ const EditChannelModal = (props) => {
|
||||
} else {
|
||||
setAutoBan(true);
|
||||
}
|
||||
// 同步企业账户状态
|
||||
setIsEnterpriseAccount(data.is_enterprise_account || false);
|
||||
setBasicModels(getChannelModels(data.type));
|
||||
// 同步更新channelSettings状态显示
|
||||
setChannelSettings({
|
||||
@@ -692,6 +733,8 @@ const EditChannelModal = (props) => {
|
||||
});
|
||||
// 重置密钥模式状态
|
||||
setKeyMode('append');
|
||||
// 重置企业账户状态
|
||||
setIsEnterpriseAccount(false);
|
||||
// 清空表单中的key_mode字段
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key_mode', undefined);
|
||||
@@ -802,7 +845,9 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.key;
|
||||
}
|
||||
} else {
|
||||
localInputs.key = batch ? JSON.stringify(keys) : JSON.stringify(keys[0]);
|
||||
localInputs.key = batch
|
||||
? JSON.stringify(keys)
|
||||
: JSON.stringify(keys[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -822,6 +867,10 @@ const EditChannelModal = (props) => {
|
||||
showInfo(t('请至少选择一个模型!'));
|
||||
return;
|
||||
}
|
||||
if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) {
|
||||
showInfo(t('请输入API地址!'));
|
||||
return;
|
||||
}
|
||||
if (
|
||||
localInputs.model_mapping &&
|
||||
localInputs.model_mapping !== '' &&
|
||||
@@ -851,6 +900,21 @@ const EditChannelModal = (props) => {
|
||||
};
|
||||
localInputs.setting = JSON.stringify(channelExtraSettings);
|
||||
|
||||
// 处理type === 20的企业账户设置
|
||||
if (localInputs.type === 20) {
|
||||
let settings = {};
|
||||
if (localInputs.settings) {
|
||||
try {
|
||||
settings = JSON.parse(localInputs.settings);
|
||||
} catch (error) {
|
||||
console.error('解析settings失败:', error);
|
||||
}
|
||||
}
|
||||
// 设置企业账户标识,无论是true还是false都要传到后端
|
||||
settings.openrouter_enterprise = localInputs.is_enterprise_account === true;
|
||||
localInputs.settings = JSON.stringify(settings);
|
||||
}
|
||||
|
||||
// 清理不需要发送到后端的字段
|
||||
delete localInputs.force_format;
|
||||
delete localInputs.thinking_to_content;
|
||||
@@ -858,6 +922,7 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.pass_through_body_enabled;
|
||||
delete localInputs.system_prompt;
|
||||
delete localInputs.system_prompt_override;
|
||||
delete localInputs.is_enterprise_account;
|
||||
// 顶层的 vertex_key_type 不应发送给后端
|
||||
delete localInputs.vertex_key_type;
|
||||
|
||||
@@ -899,6 +964,56 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
};
|
||||
|
||||
// 密钥去重函数
|
||||
const deduplicateKeys = () => {
|
||||
const currentKey = formApiRef.current?.getValue('key') || inputs.key || '';
|
||||
|
||||
if (!currentKey.trim()) {
|
||||
showInfo(t('请先输入密钥'));
|
||||
return;
|
||||
}
|
||||
|
||||
// 按行分割密钥
|
||||
const keyLines = currentKey.split('\n');
|
||||
const beforeCount = keyLines.length;
|
||||
|
||||
// 使用哈希表去重,保持原有顺序
|
||||
const keySet = new Set();
|
||||
const deduplicatedKeys = [];
|
||||
|
||||
keyLines.forEach((line) => {
|
||||
const trimmedLine = line.trim();
|
||||
if (trimmedLine && !keySet.has(trimmedLine)) {
|
||||
keySet.add(trimmedLine);
|
||||
deduplicatedKeys.push(trimmedLine);
|
||||
}
|
||||
});
|
||||
|
||||
const afterCount = deduplicatedKeys.length;
|
||||
const deduplicatedKeyText = deduplicatedKeys.join('\n');
|
||||
|
||||
// 更新表单和状态
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key', deduplicatedKeyText);
|
||||
}
|
||||
handleInputChange('key', deduplicatedKeyText);
|
||||
|
||||
// 显示去重结果
|
||||
const message = t(
|
||||
'去重完成:去重前 {{before}} 个密钥,去重后 {{after}} 个密钥',
|
||||
{
|
||||
before: beforeCount,
|
||||
after: afterCount,
|
||||
},
|
||||
);
|
||||
|
||||
if (beforeCount === afterCount) {
|
||||
showInfo(t('未发现重复密钥'));
|
||||
} else {
|
||||
showSuccess(message);
|
||||
}
|
||||
};
|
||||
|
||||
const addCustomModels = () => {
|
||||
if (customModel.trim() === '') return;
|
||||
const modelArray = customModel.split(',').map((model) => model.trim());
|
||||
@@ -994,24 +1109,41 @@ const EditChannelModal = (props) => {
|
||||
</Checkbox>
|
||||
)}
|
||||
{batch && (
|
||||
<Checkbox
|
||||
disabled={isEdit}
|
||||
checked={multiToSingle}
|
||||
onChange={() => {
|
||||
setMultiToSingle((prev) => !prev);
|
||||
setInputs((prev) => {
|
||||
const newInputs = { ...prev };
|
||||
if (!multiToSingle) {
|
||||
newInputs.multi_key_mode = multiKeyMode;
|
||||
} else {
|
||||
delete newInputs.multi_key_mode;
|
||||
}
|
||||
return newInputs;
|
||||
});
|
||||
}}
|
||||
>
|
||||
{t('密钥聚合模式')}
|
||||
</Checkbox>
|
||||
<>
|
||||
<Checkbox
|
||||
disabled={isEdit}
|
||||
checked={multiToSingle}
|
||||
onChange={() => {
|
||||
setMultiToSingle((prev) => {
|
||||
const nextValue = !prev;
|
||||
setInputs((prevInputs) => {
|
||||
const newInputs = { ...prevInputs };
|
||||
if (nextValue) {
|
||||
newInputs.multi_key_mode = multiKeyMode;
|
||||
} else {
|
||||
delete newInputs.multi_key_mode;
|
||||
}
|
||||
return newInputs;
|
||||
});
|
||||
return nextValue;
|
||||
});
|
||||
}}
|
||||
>
|
||||
{t('密钥聚合模式')}
|
||||
</Checkbox>
|
||||
|
||||
{inputs.type !== 41 && (
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
theme='outline'
|
||||
onClick={deduplicateKeys}
|
||||
style={{ textDecoration: 'underline' }}
|
||||
>
|
||||
{t('密钥去重')}
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Space>
|
||||
) : null;
|
||||
@@ -1175,6 +1307,21 @@ const EditChannelModal = (props) => {
|
||||
onChange={(value) => handleInputChange('type', value)}
|
||||
/>
|
||||
|
||||
{inputs.type === 20 && (
|
||||
<Form.Switch
|
||||
field='is_enterprise_account'
|
||||
label={t('是否为企业账户')}
|
||||
checkedText={t('是')}
|
||||
uncheckedText={t('否')}
|
||||
onChange={(value) => {
|
||||
setIsEnterpriseAccount(value);
|
||||
handleInputChange('is_enterprise_account', value);
|
||||
}}
|
||||
extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')}
|
||||
initValue={inputs.is_enterprise_account}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Form.Input
|
||||
field='name'
|
||||
label={t('名称')}
|
||||
@@ -1198,7 +1345,10 @@ const EditChannelModal = (props) => {
|
||||
value={inputs.vertex_key_type || 'json'}
|
||||
onChange={(value) => {
|
||||
// 更新设置中的 vertex_key_type
|
||||
handleChannelOtherSettingsChange('vertex_key_type', value);
|
||||
handleChannelOtherSettingsChange(
|
||||
'vertex_key_type',
|
||||
value,
|
||||
);
|
||||
// 切换为 api_key 时,关闭批量与手动/文件切换,并清理已选文件
|
||||
if (value === 'api_key') {
|
||||
setBatch(false);
|
||||
@@ -1218,7 +1368,8 @@ const EditChannelModal = (props) => {
|
||||
/>
|
||||
)}
|
||||
{batch ? (
|
||||
inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? (
|
||||
inputs.type === 41 &&
|
||||
(inputs.vertex_key_type || 'json') === 'json' ? (
|
||||
<Form.Upload
|
||||
field='vertex_files'
|
||||
label={t('密钥文件 (.json)')}
|
||||
@@ -1254,7 +1405,7 @@ const EditChannelModal = (props) => {
|
||||
autoComplete='new-password'
|
||||
onChange={(value) => handleInputChange('key', value)}
|
||||
extraText={
|
||||
<div className='flex items-center gap-2'>
|
||||
<div className='flex items-center gap-2 flex-wrap'>
|
||||
{isEdit &&
|
||||
isMultiKeyChannel &&
|
||||
keyMode === 'append' && (
|
||||
@@ -1282,7 +1433,8 @@ const EditChannelModal = (props) => {
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
{inputs.type === 41 && (inputs.vertex_key_type || 'json') === 'json' ? (
|
||||
{inputs.type === 41 &&
|
||||
(inputs.vertex_key_type || 'json') === 'json' ? (
|
||||
<>
|
||||
{!batch && (
|
||||
<div className='flex items-center justify-between mb-3'>
|
||||
@@ -1789,6 +1941,30 @@ const EditChannelModal = (props) => {
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{inputs.type === 45 && (
|
||||
<div>
|
||||
<Form.Select
|
||||
field='base_url'
|
||||
label={t('API地址')}
|
||||
placeholder={t('请选择API地址')}
|
||||
onChange={(value) =>
|
||||
handleInputChange('base_url', value)
|
||||
}
|
||||
optionList={[
|
||||
{
|
||||
value: 'https://ark.cn-beijing.volces.com',
|
||||
label: 'https://ark.cn-beijing.volces.com'
|
||||
},
|
||||
{
|
||||
value: 'https://ark.ap-southeast.bytepluses.com',
|
||||
label: 'https://ark.ap-southeast.bytepluses.com'
|
||||
}
|
||||
]}
|
||||
defaultValue='https://ark.cn-beijing.volces.com'
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
)}
|
||||
|
||||
@@ -1872,13 +2048,15 @@ const EditChannelModal = (props) => {
|
||||
>
|
||||
{t('填入所有模型')}
|
||||
</Button>
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
onClick={() => fetchUpstreamModelList('models')}
|
||||
>
|
||||
{t('获取模型列表')}
|
||||
</Button>
|
||||
{MODEL_FETCHABLE_TYPES.has(inputs.type) && (
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
onClick={() => fetchUpstreamModelList('models')}
|
||||
>
|
||||
{t('获取模型列表')}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
size='small'
|
||||
type='warning'
|
||||
|
||||
@@ -247,6 +247,32 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
|
||||
}
|
||||
};
|
||||
|
||||
// Delete a specific key
|
||||
const handleDeleteKey = async (keyIndex) => {
|
||||
const operationId = `delete_${keyIndex}`;
|
||||
setOperationLoading((prev) => ({ ...prev, [operationId]: true }));
|
||||
|
||||
try {
|
||||
const res = await API.post('/api/channel/multi_key/manage', {
|
||||
channel_id: channel.id,
|
||||
action: 'delete_key',
|
||||
key_index: keyIndex,
|
||||
});
|
||||
|
||||
if (res.data.success) {
|
||||
showSuccess(t('密钥已删除'));
|
||||
await loadKeyStatus(currentPage, pageSize); // Reload current page
|
||||
onRefresh && onRefresh(); // Refresh parent component
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('删除密钥失败'));
|
||||
} finally {
|
||||
setOperationLoading((prev) => ({ ...prev, [operationId]: false }));
|
||||
}
|
||||
};
|
||||
|
||||
// Handle page change
|
||||
const handlePageChange = (page) => {
|
||||
setCurrentPage(page);
|
||||
@@ -384,7 +410,7 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
|
||||
title: t('操作'),
|
||||
key: 'action',
|
||||
fixed: 'right',
|
||||
width: 100,
|
||||
width: 150,
|
||||
render: (_, record) => (
|
||||
<Space>
|
||||
{record.status === 1 ? (
|
||||
@@ -406,6 +432,21 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
|
||||
{t('启用')}
|
||||
</Button>
|
||||
)}
|
||||
<Popconfirm
|
||||
title={t('确定要删除此密钥吗?')}
|
||||
content={t('此操作不可撤销,将永久删除该密钥')}
|
||||
onConfirm={() => handleDeleteKey(record.index)}
|
||||
okType={'danger'}
|
||||
position={'topRight'}
|
||||
>
|
||||
<Button
|
||||
type='danger'
|
||||
size='small'
|
||||
loading={operationLoading[`delete_${record.index}`]}
|
||||
>
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
|
||||
@@ -17,8 +17,11 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal, Button, Typography, Spin } from '@douyinfe/semi-ui';
|
||||
import { IconExternalOpen, IconCopy } from '@douyinfe/semi-icons';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const ContentModal = ({
|
||||
isModalOpen,
|
||||
@@ -26,17 +29,120 @@ const ContentModal = ({
|
||||
modalContent,
|
||||
isVideo,
|
||||
}) => {
|
||||
const [videoError, setVideoError] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (isModalOpen && isVideo) {
|
||||
setVideoError(false);
|
||||
setIsLoading(true);
|
||||
}
|
||||
}, [isModalOpen, isVideo]);
|
||||
|
||||
const handleVideoError = () => {
|
||||
setVideoError(true);
|
||||
setIsLoading(false);
|
||||
};
|
||||
|
||||
const handleVideoLoaded = () => {
|
||||
setIsLoading(false);
|
||||
};
|
||||
|
||||
const handleCopyUrl = () => {
|
||||
navigator.clipboard.writeText(modalContent);
|
||||
};
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(modalContent, '_blank');
|
||||
};
|
||||
|
||||
const renderVideoContent = () => {
|
||||
if (videoError) {
|
||||
return (
|
||||
<div style={{ textAlign: 'center', padding: '40px' }}>
|
||||
<Text type="tertiary" style={{ display: 'block', marginBottom: '16px' }}>
|
||||
视频无法在当前浏览器中播放,这可能是由于:
|
||||
</Text>
|
||||
<Text type="tertiary" style={{ display: 'block', marginBottom: '8px', fontSize: '12px' }}>
|
||||
• 视频服务商的跨域限制
|
||||
</Text>
|
||||
<Text type="tertiary" style={{ display: 'block', marginBottom: '8px', fontSize: '12px' }}>
|
||||
• 需要特定的请求头或认证
|
||||
</Text>
|
||||
<Text type="tertiary" style={{ display: 'block', marginBottom: '16px', fontSize: '12px' }}>
|
||||
• 防盗链保护机制
|
||||
</Text>
|
||||
|
||||
<div style={{ marginTop: '20px' }}>
|
||||
<Button
|
||||
icon={<IconExternalOpen />}
|
||||
onClick={handleOpenInNewTab}
|
||||
style={{ marginRight: '8px' }}
|
||||
>
|
||||
在新标签页中打开
|
||||
</Button>
|
||||
<Button
|
||||
icon={<IconCopy />}
|
||||
onClick={handleCopyUrl}
|
||||
>
|
||||
复制链接
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div style={{ marginTop: '16px', padding: '8px', backgroundColor: '#f8f9fa', borderRadius: '4px' }}>
|
||||
<Text
|
||||
type="tertiary"
|
||||
style={{ fontSize: '10px', wordBreak: 'break-all' }}
|
||||
>
|
||||
{modalContent}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ position: 'relative' }}>
|
||||
{isLoading && (
|
||||
<div style={{
|
||||
position: 'absolute',
|
||||
top: '50%',
|
||||
left: '50%',
|
||||
transform: 'translate(-50%, -50%)',
|
||||
zIndex: 10
|
||||
}}>
|
||||
<Spin size="large" />
|
||||
</div>
|
||||
)}
|
||||
<video
|
||||
src={modalContent}
|
||||
controls
|
||||
style={{ width: '100%' }}
|
||||
autoPlay
|
||||
crossOrigin="anonymous"
|
||||
onError={handleVideoError}
|
||||
onLoadedData={handleVideoLoaded}
|
||||
onLoadStart={() => setIsLoading(true)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
visible={isModalOpen}
|
||||
onOk={() => setIsModalOpen(false)}
|
||||
onCancel={() => setIsModalOpen(false)}
|
||||
closable={null}
|
||||
bodyStyle={{ height: '400px', overflow: 'auto' }}
|
||||
bodyStyle={{
|
||||
height: isVideo ? '450px' : '400px',
|
||||
overflow: 'auto',
|
||||
padding: isVideo && videoError ? '0' : '24px'
|
||||
}}
|
||||
width={800}
|
||||
>
|
||||
{isVideo ? (
|
||||
<video src={modalContent} controls style={{ width: '100%' }} autoPlay />
|
||||
renderVideoContent()
|
||||
) : (
|
||||
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
|
||||
)}
|
||||
|
||||
@@ -118,7 +118,6 @@ export const buildApiPayload = (
|
||||
model: inputs.model,
|
||||
group: inputs.group,
|
||||
messages: processedMessages,
|
||||
group: inputs.group,
|
||||
stream: inputs.stream,
|
||||
};
|
||||
|
||||
@@ -132,13 +131,15 @@ export const buildApiPayload = (
|
||||
seed: 'seed',
|
||||
};
|
||||
|
||||
|
||||
Object.entries(parameterMappings).forEach(([key, param]) => {
|
||||
if (
|
||||
parameterEnabled[key] &&
|
||||
inputs[param] !== undefined &&
|
||||
inputs[param] !== null
|
||||
) {
|
||||
payload[param] = inputs[param];
|
||||
const enabled = parameterEnabled[key];
|
||||
const value = inputs[param];
|
||||
const hasValue = value !== undefined && value !== null;
|
||||
|
||||
|
||||
if (enabled && hasValue) {
|
||||
payload[param] = value;
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -75,13 +75,17 @@ export async function copy(text) {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} catch (e) {
|
||||
try {
|
||||
// 构建input 执行 复制命令
|
||||
var _input = window.document.createElement('input');
|
||||
_input.value = text;
|
||||
window.document.body.appendChild(_input);
|
||||
_input.select();
|
||||
window.document.execCommand('Copy');
|
||||
window.document.body.removeChild(_input);
|
||||
// 构建 textarea 执行复制命令,保留多行文本格式
|
||||
const textarea = window.document.createElement('textarea');
|
||||
textarea.value = text;
|
||||
textarea.setAttribute('readonly', '');
|
||||
textarea.style.position = 'fixed';
|
||||
textarea.style.left = '-9999px';
|
||||
textarea.style.top = '-9999px';
|
||||
window.document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
window.document.execCommand('copy');
|
||||
window.document.body.removeChild(textarea);
|
||||
} catch (e) {
|
||||
okay = false;
|
||||
console.error(e);
|
||||
|
||||
@@ -25,13 +25,9 @@ import {
|
||||
showInfo,
|
||||
showSuccess,
|
||||
loadChannelModels,
|
||||
copy,
|
||||
copy
|
||||
} from '../../helpers';
|
||||
import {
|
||||
CHANNEL_OPTIONS,
|
||||
ITEMS_PER_PAGE,
|
||||
MODEL_TABLE_PAGE_SIZE,
|
||||
} from '../../constants';
|
||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE, MODEL_TABLE_PAGE_SIZE } from '../../constants';
|
||||
import { useIsMobile } from '../common/useIsMobile';
|
||||
import { useTableCompactMode } from '../common/useTableCompactMode';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
@@ -68,7 +64,7 @@ export const useChannelsData = () => {
|
||||
|
||||
// Status filter
|
||||
const [statusFilter, setStatusFilter] = useState(
|
||||
localStorage.getItem('channel-status-filter') || 'all',
|
||||
localStorage.getItem('channel-status-filter') || 'all'
|
||||
);
|
||||
|
||||
// Type tabs states
|
||||
@@ -83,9 +79,10 @@ export const useChannelsData = () => {
|
||||
const [testingModels, setTestingModels] = useState(new Set());
|
||||
const [selectedModelKeys, setSelectedModelKeys] = useState([]);
|
||||
const [isBatchTesting, setIsBatchTesting] = useState(false);
|
||||
const [testQueue, setTestQueue] = useState([]);
|
||||
const [isProcessingQueue, setIsProcessingQueue] = useState(false);
|
||||
const [modelTablePage, setModelTablePage] = useState(1);
|
||||
|
||||
// 使用 ref 来避免闭包问题,类似旧版实现
|
||||
const shouldStopBatchTestingRef = useRef(false);
|
||||
|
||||
// Multi-key management states
|
||||
const [showMultiKeyManageModal, setShowMultiKeyManageModal] = useState(false);
|
||||
@@ -119,12 +116,9 @@ export const useChannelsData = () => {
|
||||
// Initialize from localStorage
|
||||
useEffect(() => {
|
||||
const localIdSort = localStorage.getItem('id-sort') === 'true';
|
||||
const localPageSize =
|
||||
parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
|
||||
const localEnableTagMode =
|
||||
localStorage.getItem('enable-tag-mode') === 'true';
|
||||
const localEnableBatchDelete =
|
||||
localStorage.getItem('enable-batch-delete') === 'true';
|
||||
const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE;
|
||||
const localEnableTagMode = localStorage.getItem('enable-tag-mode') === 'true';
|
||||
const localEnableBatchDelete = localStorage.getItem('enable-batch-delete') === 'true';
|
||||
|
||||
setIdSort(localIdSort);
|
||||
setPageSize(localPageSize);
|
||||
@@ -182,10 +176,7 @@ export const useChannelsData = () => {
|
||||
// Save column preferences
|
||||
useEffect(() => {
|
||||
if (Object.keys(visibleColumns).length > 0) {
|
||||
localStorage.setItem(
|
||||
'channels-table-columns',
|
||||
JSON.stringify(visibleColumns),
|
||||
);
|
||||
localStorage.setItem('channels-table-columns', JSON.stringify(visibleColumns));
|
||||
}
|
||||
}, [visibleColumns]);
|
||||
|
||||
@@ -299,21 +290,14 @@ export const useChannelsData = () => {
|
||||
const { searchKeyword, searchGroup, searchModel } = getFormValues();
|
||||
if (searchKeyword !== '' || searchGroup !== '' || searchModel !== '') {
|
||||
setLoading(true);
|
||||
await searchChannels(
|
||||
enableTagMode,
|
||||
typeKey,
|
||||
statusF,
|
||||
page,
|
||||
pageSize,
|
||||
idSort,
|
||||
);
|
||||
await searchChannels(enableTagMode, typeKey, statusF, page, pageSize, idSort);
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const reqId = ++requestCounter.current;
|
||||
setLoading(true);
|
||||
const typeParam = typeKey !== 'all' ? `&type=${typeKey}` : '';
|
||||
const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : '';
|
||||
const statusParam = statusF !== 'all' ? `&status=${statusF}` : '';
|
||||
const res = await API.get(
|
||||
`/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}${statusParam}`,
|
||||
@@ -327,10 +311,7 @@ export const useChannelsData = () => {
|
||||
if (success) {
|
||||
const { items, total, type_counts } = data;
|
||||
if (type_counts) {
|
||||
const sumAll = Object.values(type_counts).reduce(
|
||||
(acc, v) => acc + v,
|
||||
0,
|
||||
);
|
||||
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
|
||||
setTypeCounts({ ...type_counts, all: sumAll });
|
||||
}
|
||||
setChannelFormat(items, enableTagMode);
|
||||
@@ -354,18 +335,11 @@ export const useChannelsData = () => {
|
||||
setSearching(true);
|
||||
try {
|
||||
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
|
||||
await loadChannels(
|
||||
page,
|
||||
pageSz,
|
||||
sortFlag,
|
||||
enableTagMode,
|
||||
typeKey,
|
||||
statusF,
|
||||
);
|
||||
await loadChannels(page, pageSz, sortFlag, enableTagMode, typeKey, statusF);
|
||||
return;
|
||||
}
|
||||
|
||||
const typeParam = typeKey !== 'all' ? `&type=${typeKey}` : '';
|
||||
const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : '';
|
||||
const statusParam = statusF !== 'all' ? `&status=${statusF}` : '';
|
||||
const res = await API.get(
|
||||
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${sortFlag}&tag_mode=${enableTagMode}&p=${page}&page_size=${pageSz}${typeParam}${statusParam}`,
|
||||
@@ -373,10 +347,7 @@ export const useChannelsData = () => {
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
const { items = [], total = 0, type_counts = {} } = data;
|
||||
const sumAll = Object.values(type_counts).reduce(
|
||||
(acc, v) => acc + v,
|
||||
0,
|
||||
);
|
||||
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
|
||||
setTypeCounts({ ...type_counts, all: sumAll });
|
||||
setChannelFormat(items, enableTagMode);
|
||||
setChannelCount(total);
|
||||
@@ -395,14 +366,7 @@ export const useChannelsData = () => {
|
||||
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
|
||||
await loadChannels(page, pageSize, idSort, enableTagMode);
|
||||
} else {
|
||||
await searchChannels(
|
||||
enableTagMode,
|
||||
activeTypeKey,
|
||||
statusFilter,
|
||||
page,
|
||||
pageSize,
|
||||
idSort,
|
||||
);
|
||||
await searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -488,16 +452,9 @@ export const useChannelsData = () => {
|
||||
const { searchKeyword, searchGroup, searchModel } = getFormValues();
|
||||
setActivePage(page);
|
||||
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
|
||||
loadChannels(page, pageSize, idSort, enableTagMode).then(() => {});
|
||||
loadChannels(page, pageSize, idSort, enableTagMode).then(() => { });
|
||||
} else {
|
||||
searchChannels(
|
||||
enableTagMode,
|
||||
activeTypeKey,
|
||||
statusFilter,
|
||||
page,
|
||||
pageSize,
|
||||
idSort,
|
||||
);
|
||||
searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -513,14 +470,7 @@ export const useChannelsData = () => {
|
||||
showError(reason);
|
||||
});
|
||||
} else {
|
||||
searchChannels(
|
||||
enableTagMode,
|
||||
activeTypeKey,
|
||||
statusFilter,
|
||||
1,
|
||||
size,
|
||||
idSort,
|
||||
);
|
||||
searchChannels(enableTagMode, activeTypeKey, statusFilter, 1, size, idSort);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -551,10 +501,7 @@ export const useChannelsData = () => {
|
||||
showError(res?.data?.message || t('渠道复制失败'));
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('渠道复制失败: ') +
|
||||
(error?.response?.data?.message || error?.message || error),
|
||||
);
|
||||
showError(t('渠道复制失败: ') + (error?.response?.data?.message || error?.message || error));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -593,11 +540,7 @@ export const useChannelsData = () => {
|
||||
data.priority = parseInt(data.priority);
|
||||
break;
|
||||
case 'weight':
|
||||
if (
|
||||
data.weight === undefined ||
|
||||
data.weight < 0 ||
|
||||
data.weight === ''
|
||||
) {
|
||||
if (data.weight === undefined || data.weight < 0 || data.weight === '') {
|
||||
showInfo('权重必须是非负整数!');
|
||||
return;
|
||||
}
|
||||
@@ -740,136 +683,226 @@ export const useChannelsData = () => {
|
||||
const res = await API.post(`/api/channel/fix`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(
|
||||
t('已修复 ${success} 个通道,失败 ${fails} 个通道。')
|
||||
.replace('${success}', data.success)
|
||||
.replace('${fails}', data.fails),
|
||||
);
|
||||
showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
|
||||
await refresh();
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
// Test channel
|
||||
// Test channel - 单个模型测试,参考旧版实现
|
||||
const testChannel = async (record, model) => {
|
||||
setTestQueue((prev) => [...prev, { channel: record, model }]);
|
||||
if (!isProcessingQueue) {
|
||||
setIsProcessingQueue(true);
|
||||
const testKey = `${record.id}-${model}`;
|
||||
|
||||
// 检查是否应该停止批量测试
|
||||
if (shouldStopBatchTestingRef.current && isBatchTesting) {
|
||||
return Promise.resolve();
|
||||
}
|
||||
};
|
||||
|
||||
// Process test queue
|
||||
const processTestQueue = async () => {
|
||||
if (!isProcessingQueue || testQueue.length === 0) return;
|
||||
|
||||
const { channel, model, indexInFiltered } = testQueue[0];
|
||||
|
||||
if (currentTestChannel && currentTestChannel.id === channel.id) {
|
||||
let pageNo;
|
||||
if (indexInFiltered !== undefined) {
|
||||
pageNo = Math.floor(indexInFiltered / MODEL_TABLE_PAGE_SIZE) + 1;
|
||||
} else {
|
||||
const filteredModelsList = currentTestChannel.models
|
||||
.split(',')
|
||||
.filter((m) =>
|
||||
m.toLowerCase().includes(modelSearchKeyword.toLowerCase()),
|
||||
);
|
||||
const modelIdx = filteredModelsList.indexOf(model);
|
||||
pageNo =
|
||||
modelIdx !== -1
|
||||
? Math.floor(modelIdx / MODEL_TABLE_PAGE_SIZE) + 1
|
||||
: 1;
|
||||
}
|
||||
setModelTablePage(pageNo);
|
||||
}
|
||||
// 添加到正在测试的模型集合
|
||||
setTestingModels(prev => new Set([...prev, model]));
|
||||
|
||||
try {
|
||||
setTestingModels((prev) => new Set([...prev, model]));
|
||||
const res = await API.get(
|
||||
`/api/channel/test/${channel.id}?model=${model}`,
|
||||
);
|
||||
const res = await API.get(`/api/channel/test/${record.id}?model=${model}`);
|
||||
|
||||
// 检查是否在请求期间被停止
|
||||
if (shouldStopBatchTestingRef.current && isBatchTesting) {
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
const { success, message, time } = res.data;
|
||||
|
||||
setModelTestResults((prev) => ({
|
||||
// 更新测试结果
|
||||
setModelTestResults(prev => ({
|
||||
...prev,
|
||||
[`${channel.id}-${model}`]: { success, time },
|
||||
[testKey]: {
|
||||
success,
|
||||
message,
|
||||
time: time || 0,
|
||||
timestamp: Date.now()
|
||||
}
|
||||
}));
|
||||
|
||||
if (success) {
|
||||
updateChannelProperty(channel.id, (ch) => {
|
||||
ch.response_time = time * 1000;
|
||||
ch.test_time = Date.now() / 1000;
|
||||
// 更新渠道响应时间
|
||||
updateChannelProperty(record.id, (channel) => {
|
||||
channel.response_time = time * 1000;
|
||||
channel.test_time = Date.now() / 1000;
|
||||
});
|
||||
if (!model) {
|
||||
|
||||
if (!model || model === '') {
|
||||
showInfo(
|
||||
t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。')
|
||||
.replace('${name}', channel.name)
|
||||
.replace('${name}', record.name)
|
||||
.replace('${time.toFixed(2)}', time.toFixed(2)),
|
||||
);
|
||||
} else {
|
||||
showInfo(
|
||||
t('通道 ${name} 测试成功,模型 ${model} 耗时 ${time.toFixed(2)} 秒。')
|
||||
.replace('${name}', record.name)
|
||||
.replace('${model}', model)
|
||||
.replace('${time.toFixed(2)}', time.toFixed(2)),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
showError(`${t('模型')} ${model}: ${message}`);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(error.message);
|
||||
// 处理网络错误
|
||||
const testKey = `${record.id}-${model}`;
|
||||
setModelTestResults(prev => ({
|
||||
...prev,
|
||||
[testKey]: {
|
||||
success: false,
|
||||
message: error.message || t('网络错误'),
|
||||
time: 0,
|
||||
timestamp: Date.now()
|
||||
}
|
||||
}));
|
||||
showError(`${t('模型')} ${model}: ${error.message || t('测试失败')}`);
|
||||
} finally {
|
||||
setTestingModels((prev) => {
|
||||
// 从正在测试的模型集合中移除
|
||||
setTestingModels(prev => {
|
||||
const newSet = new Set(prev);
|
||||
newSet.delete(model);
|
||||
return newSet;
|
||||
});
|
||||
}
|
||||
|
||||
setTestQueue((prev) => prev.slice(1));
|
||||
};
|
||||
|
||||
// Monitor queue changes
|
||||
useEffect(() => {
|
||||
if (testQueue.length > 0 && isProcessingQueue) {
|
||||
processTestQueue();
|
||||
} else if (testQueue.length === 0 && isProcessingQueue) {
|
||||
setIsProcessingQueue(false);
|
||||
setIsBatchTesting(false);
|
||||
}
|
||||
}, [testQueue, isProcessingQueue]);
|
||||
|
||||
// Batch test models
|
||||
// 批量测试单个渠道的所有模型,参考旧版实现
|
||||
const batchTestModels = async () => {
|
||||
if (!currentTestChannel) return;
|
||||
if (!currentTestChannel || !currentTestChannel.models) {
|
||||
showError(t('渠道模型信息不完整'));
|
||||
return;
|
||||
}
|
||||
|
||||
const models = currentTestChannel.models.split(',').filter(model =>
|
||||
model.toLowerCase().includes(modelSearchKeyword.toLowerCase())
|
||||
);
|
||||
|
||||
if (models.length === 0) {
|
||||
showError(t('没有找到匹配的模型'));
|
||||
return;
|
||||
}
|
||||
|
||||
setIsBatchTesting(true);
|
||||
setModelTablePage(1);
|
||||
shouldStopBatchTestingRef.current = false; // 重置停止标志
|
||||
|
||||
const filteredModels = currentTestChannel.models
|
||||
.split(',')
|
||||
.filter((model) =>
|
||||
model.toLowerCase().includes(modelSearchKeyword.toLowerCase()),
|
||||
);
|
||||
// 清空该渠道之前的测试结果
|
||||
setModelTestResults(prev => {
|
||||
const newResults = { ...prev };
|
||||
models.forEach(model => {
|
||||
const testKey = `${currentTestChannel.id}-${model}`;
|
||||
delete newResults[testKey];
|
||||
});
|
||||
return newResults;
|
||||
});
|
||||
|
||||
setTestQueue(
|
||||
filteredModels.map((model, idx) => ({
|
||||
channel: currentTestChannel,
|
||||
model,
|
||||
indexInFiltered: idx,
|
||||
})),
|
||||
);
|
||||
setIsProcessingQueue(true);
|
||||
try {
|
||||
showInfo(t('开始批量测试 ${count} 个模型,已清空上次结果...').replace('${count}', models.length));
|
||||
|
||||
// 提高并发数量以加快测试速度,参考旧版的并发限制
|
||||
const concurrencyLimit = 5;
|
||||
const results = [];
|
||||
|
||||
for (let i = 0; i < models.length; i += concurrencyLimit) {
|
||||
// 检查是否应该停止
|
||||
if (shouldStopBatchTestingRef.current) {
|
||||
showInfo(t('批量测试已停止'));
|
||||
break;
|
||||
}
|
||||
|
||||
const batch = models.slice(i, i + concurrencyLimit);
|
||||
showInfo(t('正在测试第 ${current} - ${end} 个模型 (共 ${total} 个)')
|
||||
.replace('${current}', i + 1)
|
||||
.replace('${end}', Math.min(i + concurrencyLimit, models.length))
|
||||
.replace('${total}', models.length)
|
||||
);
|
||||
|
||||
const batchPromises = batch.map(model => testChannel(currentTestChannel, model));
|
||||
const batchResults = await Promise.allSettled(batchPromises);
|
||||
results.push(...batchResults);
|
||||
|
||||
// 再次检查是否应该停止
|
||||
if (shouldStopBatchTestingRef.current) {
|
||||
showInfo(t('批量测试已停止'));
|
||||
break;
|
||||
}
|
||||
|
||||
// 短暂延迟避免过于频繁的请求
|
||||
if (i + concurrencyLimit < models.length) {
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
}
|
||||
}
|
||||
|
||||
if (!shouldStopBatchTestingRef.current) {
|
||||
// 等待一小段时间确保所有结果都已更新
|
||||
await new Promise(resolve => setTimeout(resolve, 300));
|
||||
|
||||
// 使用当前状态重新计算结果统计
|
||||
setModelTestResults(currentResults => {
|
||||
let successCount = 0;
|
||||
let failCount = 0;
|
||||
|
||||
models.forEach(model => {
|
||||
const testKey = `${currentTestChannel.id}-${model}`;
|
||||
const result = currentResults[testKey];
|
||||
if (result && result.success) {
|
||||
successCount++;
|
||||
} else {
|
||||
failCount++;
|
||||
}
|
||||
});
|
||||
|
||||
// 显示完成消息
|
||||
setTimeout(() => {
|
||||
showSuccess(t('批量测试完成!成功: ${success}, 失败: ${fail}, 总计: ${total}')
|
||||
.replace('${success}', successCount)
|
||||
.replace('${fail}', failCount)
|
||||
.replace('${total}', models.length)
|
||||
);
|
||||
}, 100);
|
||||
|
||||
return currentResults; // 不修改状态,只是为了获取最新值
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('批量测试过程中发生错误: ') + error.message);
|
||||
} finally {
|
||||
setIsBatchTesting(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 停止批量测试
|
||||
const stopBatchTesting = () => {
|
||||
shouldStopBatchTestingRef.current = true;
|
||||
setIsBatchTesting(false);
|
||||
setTestingModels(new Set());
|
||||
showInfo(t('已停止批量测试'));
|
||||
};
|
||||
|
||||
// 清空测试结果
|
||||
const clearTestResults = () => {
|
||||
setModelTestResults({});
|
||||
showInfo(t('已清空测试结果'));
|
||||
};
|
||||
|
||||
// Handle close modal
|
||||
const handleCloseModal = () => {
|
||||
// 如果正在批量测试,先停止测试
|
||||
if (isBatchTesting) {
|
||||
setTestQueue([]);
|
||||
setIsProcessingQueue(false);
|
||||
setIsBatchTesting(false);
|
||||
showSuccess(t('已停止测试'));
|
||||
} else {
|
||||
setShowModelTestModal(false);
|
||||
setModelSearchKeyword('');
|
||||
setSelectedModelKeys([]);
|
||||
setModelTablePage(1);
|
||||
shouldStopBatchTestingRef.current = true;
|
||||
showInfo(t('关闭弹窗,已停止批量测试'));
|
||||
}
|
||||
|
||||
setShowModelTestModal(false);
|
||||
setModelSearchKeyword('');
|
||||
setIsBatchTesting(false);
|
||||
setTestingModels(new Set());
|
||||
setSelectedModelKeys([]);
|
||||
setModelTablePage(1);
|
||||
// 可选择性保留测试结果,这里不清空以便用户查看
|
||||
};
|
||||
|
||||
// Type counts
|
||||
@@ -1012,4 +1045,4 @@ export const useChannelsData = () => {
|
||||
setCompactMode,
|
||||
setActivePage,
|
||||
};
|
||||
};
|
||||
};
|
||||
@@ -837,6 +837,7 @@
|
||||
"确定要充值 $": "Confirm to top up $",
|
||||
"微信/支付宝 实付金额:": "WeChat/Alipay actual payment amount:",
|
||||
"Stripe 实付金额:": "Stripe actual payment amount:",
|
||||
"允许在 Stripe 支付中输入促销码": "Allow entering promotion codes during Stripe checkout",
|
||||
"支付中...": "Paying",
|
||||
"支付宝": "Alipay",
|
||||
"收益统计": "Income statistics",
|
||||
@@ -1889,6 +1890,10 @@
|
||||
"确定要删除所有已自动禁用的密钥吗?": "Are you sure you want to delete all automatically disabled keys?",
|
||||
"此操作不可撤销,将永久删除已自动禁用的密钥": "This operation cannot be undone, and all automatically disabled keys will be permanently deleted.",
|
||||
"删除自动禁用密钥": "Delete auto disabled keys",
|
||||
"确定要删除此密钥吗?": "Are you sure you want to delete this key?",
|
||||
"此操作不可撤销,将永久删除该密钥": "This operation cannot be undone, and the key will be permanently deleted.",
|
||||
"密钥已删除": "Key has been deleted",
|
||||
"删除密钥失败": "Failed to delete key",
|
||||
"图标": "Icon",
|
||||
"模型图标": "Model icon",
|
||||
"请输入图标名称": "Please enter the icon name",
|
||||
|
||||
@@ -32,5 +32,6 @@
|
||||
"端口配置详细说明": "限制外部请求只能访问指定端口。支持单个端口(80, 443)或端口范围(8000-8999)。空列表允许所有端口。默认包含常用Web端口。",
|
||||
"输入端口后回车,如:80 或 8000-8999": "输入端口后回车,如:80 或 8000-8999",
|
||||
"更新SSRF防护设置": "更新SSRF防护设置",
|
||||
"域名IP过滤详细说明": "⚠️此功能为实验性选项,域名可能解析到多个 IPv4/IPv6 地址,若开启,请确保 IP 过滤列表覆盖这些地址,否则可能导致访问失败。"
|
||||
"域名IP过滤详细说明": "⚠️此功能为实验性选项,域名可能解析到多个 IPv4/IPv6 地址,若开启,请确保 IP 过滤列表覆盖这些地址,否则可能导致访问失败。",
|
||||
"允许在 Stripe 支付中输入促销码": "允许在 Stripe 支付中输入促销码"
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ export default function SettingsPaymentGateway(props) {
|
||||
StripePriceId: '',
|
||||
StripeUnitPrice: 8.0,
|
||||
StripeMinTopUp: 1,
|
||||
StripePromotionCodesEnabled: false,
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
const formApiRef = useRef(null);
|
||||
@@ -63,6 +64,10 @@ export default function SettingsPaymentGateway(props) {
|
||||
props.options.StripeMinTopUp !== undefined
|
||||
? parseFloat(props.options.StripeMinTopUp)
|
||||
: 1,
|
||||
StripePromotionCodesEnabled:
|
||||
props.options.StripePromotionCodesEnabled !== undefined
|
||||
? props.options.StripePromotionCodesEnabled
|
||||
: false,
|
||||
};
|
||||
setInputs(currentInputs);
|
||||
setOriginInputs({ ...currentInputs });
|
||||
@@ -114,6 +119,16 @@ export default function SettingsPaymentGateway(props) {
|
||||
value: inputs.StripeMinTopUp.toString(),
|
||||
});
|
||||
}
|
||||
if (
|
||||
originInputs['StripePromotionCodesEnabled'] !==
|
||||
inputs.StripePromotionCodesEnabled &&
|
||||
inputs.StripePromotionCodesEnabled !== undefined
|
||||
) {
|
||||
options.push({
|
||||
key: 'StripePromotionCodesEnabled',
|
||||
value: inputs.StripePromotionCodesEnabled ? 'true' : 'false',
|
||||
});
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
const requestQueue = options.map((opt) =>
|
||||
@@ -225,6 +240,15 @@ export default function SettingsPaymentGateway(props) {
|
||||
placeholder={t('例如:2,就是最低充值2$')}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={24} md={8} lg={8} xl={8}>
|
||||
<Form.Switch
|
||||
field='StripePromotionCodesEnabled'
|
||||
size='default'
|
||||
checkedText='|'
|
||||
uncheckedText='〇'
|
||||
label={t('允许在 Stripe 支付中输入促销码')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Button onClick={submitStripeSetting}>{t('更新 Stripe 设置')}</Button>
|
||||
</Form.Section>
|
||||
|
||||
Reference in New Issue
Block a user