Compare commits

..

1 Commits

Author SHA1 Message Date
Xyfacai
cd7594f623 feat: dalle 格式支持自定义参数 2025-06-09 22:14:51 +08:00
273 changed files with 13545 additions and 21450 deletions

View File

@@ -7,8 +7,6 @@
# 调试相关配置
# 启用pprof
# ENABLE_PPROF=true
# 启用调试模式
# DEBUG=true
# 数据库相关配置
# 数据库连接字符串
@@ -43,14 +41,6 @@
# 更新任务启用
# UPDATE_TASK=true
# 对话超时设置
# 所有请求超时时间单位秒默认为0表示不限制
# RELAY_TIMEOUT=0
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=120
# Gemini 识别图片 最大图片数量
# GEMINI_VISION_MAX_IMAGE_NUM=16
# 会话密钥
# SESSION_SECRET=random_string
@@ -68,19 +58,10 @@
# GET_MEDIA_TOKEN_NOT_STREAM=true
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true
# 设置流式一次回复的超时时间
# STREAMING_TIMEOUT=90
# 节点类型
# 如果是主节点则为master
# NODE_TYPE=master
# JavaScript 运行时配置
# 是否启用默认false
# JS_RUNTIME_ENABLED=true
# 最大虚拟机数量默认8
# JS_MAX_VM_COUNT=
# 运行超时时间单位默认5
# JS_SCRIPT_TIMEOUT=
# 脚本文件夹默认scripts/
# JS_SCRIPT_PATH=

View File

@@ -1,19 +0,0 @@
### PR 类型
- [ ] Bug 修复
- [ ] 新功能
- [ ] 文档更新
- [ ] 其他
### PR 是否包含破坏性更新?
- [ ]
- [ ]
### PR 描述
**请在下方详细描述您的 PR包括目的、实现细节等。**
### **重要提示**
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**

View File

@@ -26,7 +26,6 @@ jobs:
- name: Build Frontend
env:
CI: ""
NODE_OPTIONS: "--max-old-space-size=4096"
run: |
cd web
bun install

View File

@@ -1,21 +0,0 @@
name: Check PR Branching Strategy
on:
pull_request:
types: [opened, synchronize, reopened, edited]
jobs:
check-branching-strategy:
runs-on: ubuntu-latest
steps:
- name: Enforce branching strategy
run: |
if [[ "${{ github.base_ref }}" == "main" ]]; then
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
exit 1
fi
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
exit 1
fi
echo "Branching strategy check passed."

View File

@@ -24,7 +24,8 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
FROM alpine
RUN apk upgrade --no-cache \
RUN apk update \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates

View File

@@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`

View File

@@ -100,7 +100,7 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables)
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
- `DIFY_DEBUG`Dify渠道是否输出工作流和节点信息默认 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数默认 `true`
- `GET_MEDIA_TOKEN`是否统计图片token默认 `true`
@@ -180,6 +180,7 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
其他基于New API的项目
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本
## 帮助支持

View File

@@ -1,71 +0,0 @@
package common
import "one-api/constant"
func ChannelType2APIType(channelType int) (int, bool) {
apiType := -1
switch channelType {
case constant.ChannelTypeOpenAI:
apiType = constant.APITypeOpenAI
case constant.ChannelTypeAnthropic:
apiType = constant.APITypeAnthropic
case constant.ChannelTypeBaidu:
apiType = constant.APITypeBaidu
case constant.ChannelTypePaLM:
apiType = constant.APITypePaLM
case constant.ChannelTypeZhipu:
apiType = constant.APITypeZhipu
case constant.ChannelTypeAli:
apiType = constant.APITypeAli
case constant.ChannelTypeXunfei:
apiType = constant.APITypeXunfei
case constant.ChannelTypeAIProxyLibrary:
apiType = constant.APITypeAIProxyLibrary
case constant.ChannelTypeTencent:
apiType = constant.APITypeTencent
case constant.ChannelTypeGemini:
apiType = constant.APITypeGemini
case constant.ChannelTypeZhipu_v4:
apiType = constant.APITypeZhipuV4
case constant.ChannelTypeOllama:
apiType = constant.APITypeOllama
case constant.ChannelTypePerplexity:
apiType = constant.APITypePerplexity
case constant.ChannelTypeAws:
apiType = constant.APITypeAws
case constant.ChannelTypeCohere:
apiType = constant.APITypeCohere
case constant.ChannelTypeDify:
apiType = constant.APITypeDify
case constant.ChannelTypeJina:
apiType = constant.APITypeJina
case constant.ChannelCloudflare:
apiType = constant.APITypeCloudflare
case constant.ChannelTypeSiliconFlow:
apiType = constant.APITypeSiliconFlow
case constant.ChannelTypeVertexAi:
apiType = constant.APITypeVertexAi
case constant.ChannelTypeMistral:
apiType = constant.APITypeMistral
case constant.ChannelTypeDeepSeek:
apiType = constant.APITypeDeepSeek
case constant.ChannelTypeMokaAI:
apiType = constant.APITypeMokaAI
case constant.ChannelTypeVolcEngine:
apiType = constant.APITypeVolcEngine
case constant.ChannelTypeBaiduV2:
apiType = constant.APITypeBaiduV2
case constant.ChannelTypeOpenRouter:
apiType = constant.APITypeOpenRouter
case constant.ChannelTypeXinference:
apiType = constant.APITypeXinference
case constant.ChannelTypeXai:
apiType = constant.APITypeXai
case constant.ChannelTypeCoze:
apiType = constant.APITypeCoze
}
if apiType == -1 {
return constant.APITypeOpenAI, false
}
return apiType, true
}

View File

@@ -193,3 +193,107 @@ const (
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://api.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
}

View File

@@ -1,14 +1,7 @@
package common
const (
DatabaseTypeMySQL = "mysql"
DatabaseTypeSQLite = "sqlite"
DatabaseTypePostgreSQL = "postgres"
)
var UsingSQLite = false
var UsingPostgreSQL = false
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false
var UsingClickHouse = false

View File

@@ -1,41 +0,0 @@
package common
import "one-api/constant"
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
var endpointTypes []constant.EndpointType
switch channelType {
case constant.ChannelTypeJina:
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
//case constant.ChannelTypeSunoAPI:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
//case constant.ChannelTypeKling:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
//case constant.ChannelTypeJimeng:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
case constant.ChannelTypeAws:
fallthrough
case constant.ChannelTypeAnthropic:
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
case constant.ChannelTypeVertexAi:
fallthrough
case constant.ChannelTypeGemini:
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
default:
if IsOpenAIResponseOnlyModel(modelName) {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
} else {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
}
}
if IsImageGenerationModel(modelName) {
// add to first
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
}
return endpointTypes
}

View File

@@ -2,11 +2,10 @@ package common
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"one-api/constant"
"strings"
"time"
)
const KeyRequestBody = "key_request_body"
@@ -32,7 +31,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = UnmarshalJson(requestBody, &v)
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
@@ -44,45 +43,3 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
c.Set(string(key), value)
}
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
return c.Get(string(key))
}
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
return c.GetString(string(key))
}
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
return c.GetInt(string(key))
}
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
return c.GetBool(string(key))
}
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
return c.GetStringSlice(string(key))
}
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
return c.GetStringMap(string(key))
}
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
return c.GetTime(string(key))
}
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
if value, ok := c.Get(string(key)); ok {
if v, ok := value.(T); ok {
return v, true
}
}
var t T
return t, false
}

View File

@@ -1,57 +0,0 @@
package common
import (
"bytes"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
)
func CloseResponseBodyGracefully(httpResponse *http.Response) {
if httpResponse == nil || httpResponse.Body == nil {
return
}
err := httpResponse.Body.Close()
if err != nil {
SysError("failed to close response body: " + err.Error())
}
}
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
if c.Writer == nil {
return
}
body := io.NopCloser(bytes.NewBuffer(data))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
if src != nil {
for k, v := range src.Header {
// avoid setting Content-Length
if k == "Content-Length" {
continue
}
c.Writer.Header().Set(k, v[0])
}
}
// set Content-Length header manually BEFORE calling WriteHeader
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
// Write header with status code (this sends the headers)
if src != nil {
c.Writer.WriteHeader(src.StatusCode)
} else {
c.Writer.WriteHeader(http.StatusOK)
}
_, err := io.Copy(c.Writer, body)
if err != nil {
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
}
}

View File

@@ -4,7 +4,6 @@ import (
"flag"
"fmt"
"log"
"one-api/constant"
"os"
"path/filepath"
"strconv"
@@ -25,7 +24,7 @@ func printHelp() {
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
}
func InitEnv() {
func LoadEnv() {
flag.Parse()
if *PrintVersion {
@@ -96,25 +95,4 @@ func InitEnv() {
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
initConstantEnv()
}
func initConstantEnv() {
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
}

View File

@@ -5,16 +5,12 @@ import (
"encoding/json"
)
func UnmarshalJson(data []byte, v any) error {
return json.Unmarshal(data, v)
func DecodeJson(data []byte, v any) error {
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
}
func UnmarshalJsonStr(data string, v any) error {
return json.Unmarshal(StringToByteSlice(data), v)
}
func DecodeJson(reader *bytes.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
func DecodeJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v)
}
func EncodeJson(v any) ([]byte, error) {

View File

@@ -1,42 +0,0 @@
package common
import "strings"
var (
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
OpenAIResponseOnlyModels = []string{
"o3-pro",
"o3-deep-research",
"o4-mini-deep-research",
}
ImageGenerationModels = []string{
"dall-e-3",
"dall-e-2",
"gpt-image-1",
"prefix:imagen-",
"flux-",
"flux.1-",
}
)
func IsOpenAIResponseOnlyModel(modelName string) bool {
for _, m := range OpenAIResponseOnlyModels {
if strings.Contains(modelName, m) {
return true
}
}
return false
}
func IsImageGenerationModel(modelName string) bool {
modelName = strings.ToLower(modelName)
for _, m := range ImageGenerationModels {
if strings.Contains(modelName, m) {
return true
}
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
return true
}
}
return false
}

View File

@@ -1,62 +0,0 @@
package common
import (
"github.com/gin-gonic/gin"
"strconv"
)
type PageInfo struct {
Page int `json:"page"` // page num 页码
PageSize int `json:"page_size"` // page size 页大小
StartTimestamp int64 `json:"start_timestamp"` // 秒级
EndTimestamp int64 `json:"end_timestamp"` // 秒级
Total int `json:"total"` // 总条数,后设置
Items any `json:"items"` // 数据,后设置
}
func (p *PageInfo) GetStartIdx() int {
return (p.Page - 1) * p.PageSize
}
func (p *PageInfo) GetEndIdx() int {
return p.Page * p.PageSize
}
func (p *PageInfo) GetPageSize() int {
return p.PageSize
}
func (p *PageInfo) GetPage() int {
return p.Page
}
func (p *PageInfo) SetTotal(total int) {
p.Total = total
}
func (p *PageInfo) SetItems(items any) {
p.Items = items
}
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
pageInfo := &PageInfo{}
err := c.BindQuery(pageInfo)
if err != nil {
return nil, err
}
if pageInfo.Page < 1 {
// 兼容
page, _ := strconv.Atoi(c.Query("p"))
if page != 0 {
pageInfo.Page = page
} else {
pageInfo.Page = 1
}
}
if pageInfo.PageSize == 0 {
pageInfo.PageSize = ItemsPerPage
}
return pageInfo, nil
}

View File

@@ -16,10 +16,6 @@ import (
var RDB *redis.Client
var RedisEnabled = true
func RedisKeyCacheSeconds() int {
return SyncFrequency
}
// InitRedisClient This function is called after init()
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
@@ -145,11 +141,7 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
// 只有在 expiration 大于 0 时才设置过期时间
if expiration > 0 {
txn.Expire(ctx, key, expiration)
}
txn.Expire(ctx, key, expiration)
_, err := txn.Exec(ctx)
if err != nil {

View File

@@ -1,7 +1,6 @@
package common
import (
"encoding/base64"
"encoding/json"
"math/rand"
"strconv"
@@ -69,15 +68,3 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
func EncodeBase64(str string) string {
return base64.StdEncoding.EncodeToString([]byte(str))
}
func GetJsonString(data any) string {
if data == nil {
return ""
}
b, _ := json.Marshal(data)
return string(b)
}

View File

@@ -1,149 +0,0 @@
package common
import (
"fmt"
"reflect"
)
// StructToMap 递归地把任意结构体 v 转成 map[string]any。
// - 只处理导出字段;未导出字段会被跳过。
// - 优先使用 `json:"name"` 里逗号前的部分作为键;如果是 "-" 则忽略该字段;若无 tag则使用字段名。
// - 对指针、切片、数组、嵌套结构体、map 做深度遍历,保持原始结构。
func StructToMap(v any) (map[string]any, error) {
val := reflect.ValueOf(v)
if !val.IsValid() {
return nil, fmt.Errorf("nil value")
}
for val.Kind() == reflect.Pointer {
if val.IsNil() {
return nil, fmt.Errorf("nil pointer")
}
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return nil, fmt.Errorf("expect struct, got %s", val.Kind())
}
return structValueToMap(val), nil
}
func structValueToMap(val reflect.Value) map[string]any {
out := make(map[string]any, val.NumField())
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // 未导出字段
continue
}
// 解析 json tag
tag := f.Tag.Get("json")
name, opts := parseTag(tag)
if name == "-" {
continue
}
if name == "" {
name = f.Name
}
fv := val.Field(i)
out[name] = valueToAny(fv, opts.Contains("omitempty"))
}
return out
}
// valueToAny 递归处理各种值类型。
func valueToAny(v reflect.Value, omitEmpty bool) any {
if !v.IsValid() {
return nil
}
for v.Kind() == reflect.Pointer {
if v.IsNil() {
if omitEmpty {
return nil
}
// 保持与 encoding/json 行为一致nil 指针写成 null
return nil
}
v = v.Elem()
}
switch v.Kind() {
case reflect.Struct:
return structValueToMap(v)
case reflect.Slice, reflect.Array:
l := v.Len()
arr := make([]any, l)
for i := 0; i < l; i++ {
arr[i] = valueToAny(v.Index(i), false)
}
return arr
case reflect.Map:
m := make(map[string]any, v.Len())
iter := v.MapRange()
for iter.Next() {
k := iter.Key()
// 只支持 string key与 encoding/json 一致
if k.Kind() == reflect.String {
m[k.String()] = valueToAny(iter.Value(), false)
}
}
return m
default:
// 基本类型直接返回其接口值
return v.Interface()
}
}
// tagOptions 用于判断是否包含 "omitempty"
type tagOptions string
func (o tagOptions) Contains(opt string) bool {
if len(o) == 0 {
return false
}
for _, s := range splitComma(string(o)) {
if s == opt {
return true
}
}
return false
}
func parseTag(tag string) (string, tagOptions) {
if idx := indexComma(tag); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// 避免 strings.Split 额外分配
func indexComma(s string) int {
for i, r := range s {
if r == ',' {
return i
}
}
return -1
}
func splitComma(s string) []string {
var parts []string
start := 0
for i, r := range s {
if r == ',' {
parts = append(parts, s[start:i])
start = i + 1
}
}
if start <= len(s) {
parts = append(parts, s[start:])
}
return parts
}

View File

@@ -13,7 +13,6 @@ import (
"math/big"
"math/rand"
"net"
"net/url"
"os"
"os/exec"
"runtime"
@@ -250,55 +249,13 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
}
// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
durationStr := string(bytes.TrimSpace(output))
if durationStr == "N/A" {
// Create a temporary output file name
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.Wrap(err, "failed to create temporary file")
}
tmpName := tmpFp.Name()
// Close immediately so ffmpeg can open the file on Windows.
_ = tmpFp.Close()
defer os.Remove(tmpName)
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
if err := ffmpegCmd.Run(); err != nil {
return 0, errors.Wrap(err, "failed to run ffmpeg")
}
// Recalculate the duration of the new file
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
}
durationStr = string(bytes.TrimSpace(output))
}
return strconv.ParseFloat(durationStr, 64)
}
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)
if err != nil {
return base + endpoint
}
end := endpoint
if end == "" {
end = "/"
}
ref, err := url.Parse(end)
if err != nil {
return base + endpoint
}
return u.ResolveReference(ref).String()
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
}

View File

@@ -1,26 +0,0 @@
# constant 包 (`/constant`)
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
## 当前文件
| 文件 | 说明 |
|----------------------|---------------------------------------------------------------------|
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
| `channel_setting.go` | Channel 级别的设置键,如 `proxy``force_format` 等。 |
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量请求时间、Token/Channel/User 相关信息等)。 |
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
## 使用约定
1. `constant` 包**只能被其他包引用**import**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。

View File

@@ -1,34 +0,0 @@
package constant
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeZhipuV4
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeCoze
APITypeDummy // this one is only for count, do not add any channel after this
)

View File

@@ -1,5 +1,14 @@
package constant
import "one-api/common"
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
// Cache keys
const (
UserGroupKeyFmt = "user_group:%d"

View File

@@ -1,109 +0,0 @@
package constant
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeJimeng = 51
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://api.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51
}

View File

@@ -0,0 +1,7 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)

View File

@@ -1,35 +1,10 @@
package constant
type ContextKey string
const (
ContextKeyOriginalModel ContextKey = "original_model"
ContextKeyRequestStartTime ContextKey = "request_start_time"
/* token related keys */
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
ContextKeyTokenKey ContextKey = "token_key"
ContextKeyTokenId ContextKey = "token_id"
ContextKeyTokenGroup ContextKey = "token_group"
ContextKeyTokenAllowIps ContextKey = "allow_ips"
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
/* channel related keys */
ContextKeyBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyParamOverride ContextKey = "param_override"
/* user related keys */
ContextKeyUserId ContextKey = "id"
ContextKeyUserSetting ContextKey = "user_setting"
ContextKeyUserQuota ContextKey = "user_quota"
ContextKeyUserStatus ContextKey = "user_status"
ContextKeyUserEmail ContextKey = "user_email"
ContextKeyUserGroup ContextKey = "user_group"
ContextKeyUsingGroup ContextKey = "group"
ContextKeyUserName ContextKey = "username"
ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
)

View File

@@ -1,16 +0,0 @@
package constant
type EndpointType string
const (
EndpointTypeOpenAI EndpointType = "openai"
EndpointTypeOpenAIResponse EndpointType = "openai-response"
EndpointTypeAnthropic EndpointType = "anthropic"
EndpointTypeGemini EndpointType = "gemini"
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeImageGeneration EndpointType = "image-generation"
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
//EndpointTypeSuno EndpointType = "suno-proxy"
//EndpointTypeKling EndpointType = "kling"
//EndpointTypeJimeng EndpointType = "jimeng"
)

View File

@@ -1,5 +1,9 @@
package constant
import (
"one-api/common"
)
var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
@@ -13,3 +17,39 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
func InitEnv() {
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
//}
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
// parts := strings.Split(pair, ":")
// if len(parts) == 2 {
// GeminiModelMap[parts[0]] = parts[1]
// } else {
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
// }
//}
}

View File

@@ -22,8 +22,6 @@ const (
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
MjActionVideo = "VIDEO"
MjActionEdits = "EDITS"
)
var MidjourneyModel2Action = map[string]string{
@@ -43,6 +41,4 @@ var MidjourneyModel2Action = map[string]string{
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
"mj_video": MjActionVideo,
"mj_edits": MjActionEdits,
}

View File

@@ -5,16 +5,11 @@ type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
TaskPlatformKling TaskPlatform = "kling"
TaskPlatformJimeng TaskPlatform = "jimeng"
)
const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
TaskActionGenerate = "generate"
TaskActionTextGenerate = "textGenerate"
)
var SunoModel2Action = map[string]string{

15
constant/user_setting.go Normal file
View File

@@ -0,0 +1,15 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -4,14 +4,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/shopspring/decimal"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/service"
"one-api/setting"
"strconv"
"time"
@@ -307,70 +304,34 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
return balance, nil
}
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
url := "https://api.moonshot.cn/v1/users/me/balance"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
type MoonshotBalanceData struct {
AvailableBalance float64 `json:"available_balance"`
VoucherBalance float64 `json:"voucher_balance"`
CashBalance float64 `json:"cash_balance"`
}
type MoonshotBalanceResponse struct {
Code int `json:"code"`
Data MoonshotBalanceData `json:"data"`
Scode string `json:"scode"`
Status bool `json:"status"`
}
response := MoonshotBalanceResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if !response.Status || response.Code != 0 {
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
}
availableBalanceCny := response.Data.AvailableBalance
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
channel.UpdateBalance(availableBalanceUsd)
return availableBalanceUsd, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := constant.ChannelBaseURLs[channel.Type]
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
case constant.ChannelTypeOpenAI:
case common.ChannelTypeOpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case constant.ChannelTypeAzure:
case common.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case constant.ChannelTypeCustom:
case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case constant.ChannelTypeAIProxy:
case common.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case constant.ChannelTypeAPI2GPT:
case common.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel)
case constant.ChannelTypeAIGC2D:
case common.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel)
case constant.ChannelTypeSiliconFlow:
case common.ChannelTypeSiliconFlow:
return updateChannelSiliconFlowBalance(channel)
case constant.ChannelTypeDeepSeek:
case common.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel)
case constant.ChannelTypeOpenRouter:
case common.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
case constant.ChannelTypeMoonshot:
return updateChannelMoonshotBalance(channel)
default:
return 0, errors.New("尚未实现")
}

View File

@@ -11,12 +11,12 @@ import (
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strconv"
@@ -31,21 +31,15 @@ import (
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now()
if channel.Type == constant.ChannelTypeMidjourney {
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
if channel.Type == constant.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported"), nil
if channel.Type == common.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported!!!"), nil
}
if channel.Type == constant.ChannelTypeSunoAPI {
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
if channel.Type == constant.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
}
if channel.Type == constant.ChannelTypeJimeng {
return errors.New("jimeng channel test is not supported"), nil
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -56,7 +50,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
strings.Contains(testModel, "embed") ||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
@@ -96,13 +90,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info, nil)
err = helper.ModelMappedHelper(c, info)
if err != nil {
return err, nil
}
testModel = info.UpstreamModelName
apiType, _ := common.ChannelType2APIType(channel.Type)
apiType, _ := constant.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
@@ -171,21 +165,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: info.OriginModelName,
TokenName: "模型测试",
Quota: quota,
Content: "模型测试",
UseTimeSeconds: int(consumedTime),
IsStream: false,
Group: info.UsingGroup,
Other: other,
})
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
@@ -213,14 +196,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 3000
testRequest.MaxTokens = 300
} else {
testRequest.MaxTokens = 10
}
content, _ := json.Marshal("hi")
testMessage := dto.Message{
Role: "user",
Content: "hi",
Content: content,
}
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage)
@@ -288,13 +271,6 @@ func testAllChannels(notify bool) error {
disableThreshold = 10000000 // a impossible value
}
gopool.Go(func() {
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
defer func() {
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
@@ -329,7 +305,9 @@ func testAllChannels(notify bool) error {
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"strconv"
"strings"
@@ -41,124 +40,50 @@ type OpenAIModelsResponse struct {
Success bool `json:"success"`
}
func parseStatusFilter(statusParam string) int {
switch strings.ToLower(statusParam) {
case "enabled", "1":
return common.ChannelStatusEnabled
case "disabled", "0":
return 0
default:
return -1
}
}
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
if p < 0 {
p = 0
}
if pageSize < 1 {
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
statusParam := c.Query("status")
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
statusFilter := parseStatusFilter(statusParam)
// type filter
typeStr := c.Query("type")
typeFilter := -1
if typeStr != "" {
if t, err := strconv.Atoi(typeStr); err == nil {
typeFilter = t
}
}
var total int64
if enableTagMode {
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
for _, tag := range tags {
if tag == nil || *tag == "" {
continue
}
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
if err != nil {
continue
}
filtered := make([]*model.Channel, 0)
for _, ch := range tagChannels {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
if tag != nil && *tag != "" {
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
if err == nil {
channelData = append(channelData, tagChannel...)
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
if typeFilter >= 0 && ch.Type != typeFilter {
continue
}
filtered = append(filtered, ch)
}
channelData = append(channelData, filtered...)
}
total, _ = model.CountAllTags()
} else {
baseQuery := model.DB.Model(&model.Channel{})
if typeFilter >= 0 {
baseQuery = baseQuery.Where("type = ?", typeFilter)
}
if statusFilter == common.ChannelStatusEnabled {
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
}
baseQuery.Count(&total)
order := "priority desc"
if idSort {
order = "id desc"
}
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channelData = channels
}
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
}
var results []struct {
Type int64
Count int64
}
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
typeCounts := make(map[int64]int64)
for _, r := range results {
typeCounts[r.Type] = r.Count
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": channelData,
"total": total,
"page": p,
"page_size": pageSize,
"type_counts": typeCounts,
},
"data": channelData,
})
return
}
@@ -182,15 +107,22 @@ func FetchUpstreamModels(c *gin.Context) {
return
}
baseURL := constant.ChannelBaseURLs[channel.Type]
//if channel.Type != common.ChannelTypeOpenAI {
// c.JSON(http.StatusOK, gin.H{
// "success": false,
// "message": "仅支持 OpenAI 类型渠道",
// })
// return
//}
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
switch channel.Type {
case constant.ChannelTypeGemini:
case common.ChannelTypeGemini:
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
case constant.ChannelTypeAli:
case common.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
@@ -214,7 +146,7 @@ func FetchUpstreamModels(c *gin.Context) {
var ids []string
for _, model := range result.Data {
id := model.ID
if channel.Type == constant.ChannelTypeGemini {
if channel.Type == common.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
@@ -228,7 +160,7 @@ func FetchUpstreamModels(c *gin.Context) {
}
func FixChannelsAbilities(c *gin.Context) {
success, fails, err := model.FixAbility()
count, err := model.FixAbility()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -239,10 +171,7 @@ func FixChannelsAbilities(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"success": success,
"fails": fails,
},
"data": count,
})
}
@@ -250,8 +179,6 @@ func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
statusParam := c.Query("status")
statusFilter := parseStatusFilter(statusParam)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
channelData := make([]*model.Channel, 0)
@@ -283,74 +210,10 @@ func SearchChannels(c *gin.Context) {
}
channelData = channels
}
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
filtered := make([]*model.Channel, 0, len(channelData))
for _, ch := range channelData {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
filtered = append(filtered, ch)
}
channelData = filtered
}
// calculate type counts for search results
typeCounts := make(map[int64]int64)
for _, channel := range channelData {
typeCounts[int64(channel.Type)]++
}
typeParam := c.Query("type")
typeFilter := -1
if typeParam != "" {
if tp, err := strconv.Atoi(typeParam); err == nil {
typeFilter = tp
}
}
if typeFilter >= 0 {
filtered := make([]*model.Channel, 0, len(channelData))
for _, ch := range channelData {
if ch.Type == typeFilter {
filtered = append(filtered, ch)
}
}
channelData = filtered
}
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
total := len(channelData)
startIdx := (page - 1) * pageSize
if startIdx > total {
startIdx = total
}
endIdx := startIdx + pageSize
if endIdx > total {
endIdx = total
}
pagedData := channelData[startIdx:endIdx]
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": pagedData,
"total": total,
"type_counts": typeCounts,
},
"data": channelData,
})
return
}
@@ -390,17 +253,9 @@ func AddChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel setting 格式错误:" + err.Error(),
})
return
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -625,15 +480,7 @@ func UpdateChannel(c *gin.Context) {
})
return
}
err = channel.ValidateSettings()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "channel setting 格式错误:" + err.Error(),
})
return
}
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Type == common.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -662,7 +509,6 @@ func UpdateChannel(c *gin.Context) {
})
return
}
channel.Key = ""
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -688,7 +534,7 @@ func FetchModels(c *gin.Context) {
baseURL := req.BaseURL
if baseURL == "" {
baseURL = constant.ChannelBaseURLs[req.Type]
baseURL = common.ChannelBaseURLs[req.Type]
}
client := &http.Client{}

View File

@@ -1,103 +0,0 @@
// 用于迁移检测的旧键,该文件下个版本会删除
package controller
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
func MigrateConsoleSetting(c *gin.Context) {
// 读取全部 option
opts, err := model.AllOption()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
return
}
// 建立 map
valMap := map[string]string{}
for _, o := range opts {
valMap[o.Key] = o.Value
}
// 处理 APIInfo
if v := valMap["ApiInfo"]; v != "" {
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(v), &arr); err == nil {
if len(arr) > 50 {
arr = arr[:50]
}
bytes, _ := json.Marshal(arr)
model.UpdateOption("console_setting.api_info", string(bytes))
}
model.UpdateOption("ApiInfo", "")
}
// Announcements 直接搬
if v := valMap["Announcements"]; v != "" {
model.UpdateOption("console_setting.announcements", v)
model.UpdateOption("Announcements", "")
}
// FAQ 转换
if v := valMap["FAQ"]; v != "" {
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(v), &arr); err == nil {
out := []map[string]interface{}{}
for _, item := range arr {
q, _ := item["question"].(string)
if q == "" {
q, _ = item["title"].(string)
}
a, _ := item["answer"].(string)
if a == "" {
a, _ = item["content"].(string)
}
if q != "" && a != "" {
out = append(out, map[string]interface{}{"question": q, "answer": a})
}
}
if len(out) > 50 {
out = out[:50]
}
bytes, _ := json.Marshal(out)
model.UpdateOption("console_setting.faq", string(bytes))
}
model.UpdateOption("FAQ", "")
}
// Uptime Kuma 迁移到新的 groups 结构console_setting.uptime_kuma_groups
url := valMap["UptimeKumaUrl"]
slug := valMap["UptimeKumaSlug"]
if url != "" && slug != "" {
// 仅当同时存在 URL 与 Slug 时才进行迁移
groups := []map[string]interface{}{
{
"id": 1,
"categoryName": "old",
"url": url,
"slug": slug,
"description": "",
},
}
bytes, _ := json.Marshal(groups)
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
}
// 清空旧键内容
if url != "" {
model.UpdateOption("UptimeKumaUrl", "")
}
if slug != "" {
model.UpdateOption("UptimeKumaSlug", "")
}
// 删除旧键记录
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
// 重新加载 OptionMap
model.InitOptionMap()
common.SysLog("console setting migrated")
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
}

View File

@@ -1,17 +1,15 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName := range ratio_setting.GetGroupRatioCopy() {
for groupName, _ := range setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -26,7 +24,7 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false)
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
for groupName, ratio := range setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
@@ -36,12 +34,6 @@ func GetUserGroups(c *gin.Context) {
}
}
}
if setting.GroupInUserUsableGroups("auto") {
usableGroups["auto"] = map[string]interface{}{
"ratio": "自动",
"desc": setting.GetUsableGroupDescription("auto"),
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"one-api/common"
"one-api/dto"
@@ -214,12 +215,8 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 1 {
p = 1
}
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if pageSize <= 0 {
pageSize = common.ItemsPerPage
if p < 0 {
p = 0
}
// 解析其他查询参数
@@ -230,38 +227,31 @@ func GetAllMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
total := model.CountAllTasks(queryParams)
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if setting.MjForwardUrlEnabled {
for i, midjourney := range items {
for i, midjourney := range logs {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
items[i] = midjourney
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": items,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": logs,
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 1 {
p = 1
}
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if pageSize <= 0 {
pageSize = common.ItemsPerPage
if p < 0 {
p = 0
}
userId := c.GetInt("id")
log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"),
@@ -269,23 +259,19 @@ func GetUserMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
total := model.CountAllUserTask(userId, queryParams)
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if setting.MjForwardUrlEnabled {
for i, midjourney := range items {
for i, midjourney := range logs {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
items[i] = midjourney
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": items,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": logs,
})
}

View File

@@ -1,17 +1,13 @@
package controller
import (
"slices"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/middleware"
"one-api/middleware/jsrt"
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings"
@@ -28,85 +24,59 @@ func TestStatus(c *gin.Context) {
})
return
}
// 获取HTTP统计信息
httpStats := middleware.GetStats()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Server is running",
"http_stats": httpStats,
"success": true,
"message": "Server is running",
})
return
}
func GetStatus(c *gin.Context) {
cs := console_setting.GetConsoleSetting()
data := gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
"pay_methods": setting.PayMethods,
// 面板启用开关
"api_info_enabled": cs.ApiInfoEnabled,
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
"announcements_enabled": cs.AnnouncementsEnabled,
"faq_enabled": cs.FAQEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
}
// 根据启用状态注入可选内容
if cs.ApiInfoEnabled {
data["api_info"] = console_setting.GetApiInfo()
}
if cs.AnnouncementsEnabled {
data["announcements"] = console_setting.GetAnnouncements()
}
if cs.FAQEnabled {
data["faq"] = console_setting.GetFAQ()
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": data,
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
},
})
return
}
func GetNotice(c *gin.Context) {
@@ -117,6 +87,7 @@ func GetNotice(c *gin.Context) {
"message": "",
"data": common.OptionMap["Notice"],
})
return
}
func GetAbout(c *gin.Context) {
@@ -127,6 +98,7 @@ func GetAbout(c *gin.Context) {
"message": "",
"data": common.OptionMap["About"],
})
return
}
func GetMidjourney(c *gin.Context) {
@@ -137,6 +109,7 @@ func GetMidjourney(c *gin.Context) {
"message": "",
"data": common.OptionMap["Midjourney"],
})
return
}
func GetHomePageContent(c *gin.Context) {
@@ -147,6 +120,7 @@ func GetHomePageContent(c *gin.Context) {
"message": "",
"data": common.OptionMap["HomePageContent"],
})
return
}
func SendEmailVerification(c *gin.Context) {
@@ -169,7 +143,13 @@ func SendEmailVerification(c *gin.Context) {
localPart := parts[0]
domainPart := parts[1]
if common.EmailDomainRestrictionEnabled {
allowed := slices.Contains(common.EmailDomainWhitelist, domainPart)
allowed := false
for _, domain := range common.EmailDomainWhitelist {
if domainPart == domain {
allowed = true
break
}
}
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -214,6 +194,7 @@ func SendEmailVerification(c *gin.Context) {
"success": true,
"message": "",
})
return
}
func SendPasswordResetEmail(c *gin.Context) {
@@ -252,6 +233,7 @@ func SendPasswordResetEmail(c *gin.Context) {
"success": true,
"message": "",
})
return
}
type PasswordResetRequest struct {
@@ -291,13 +273,5 @@ func ResetPassword(c *gin.Context) {
"message": "",
"data": password,
})
}
func ReloadJSScripts(c *gin.Context) {
jsrt.ReloadJSScripts()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "JavaScript 脚本已重新加载",
})
return
}

View File

@@ -3,7 +3,6 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,7 +14,7 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/setting"
relayconstant "one-api/relay/constant"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -24,10 +23,30 @@ var openAIModels []dto.OpenAIModels
var openAIModelsMap map[string]dto.OpenAIModels
var channelId2Models map[int][]string
func getPermission() []dto.OpenAIModelPermission {
var permission []dto.OpenAIModelPermission
permission = append(permission, dto.OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: false,
AllowView: true,
AllowFineTuning: false,
Organization: "*",
Group: nil,
IsBlocking: false,
})
return permission
}
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
permission := getPermission()
for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
@@ -35,51 +54,69 @@ func init() {
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: ai360.ChannelName,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: ai360.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: moonshot.ChannelName,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: moonshot.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: minimax.ChannelName,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: minimax.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "midjourney",
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "midjourney",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
openAIModelsMap = make(map[string]dto.OpenAIModels)
@@ -87,9 +124,9 @@ func init() {
openAIModelsMap[aiModel.Id] = aiModel
}
channelId2Models = make(map[int][]string)
for i := 1; i <= constant.ChannelTypeDummy; i++ {
apiType, success := common.ChannelType2APIType(i)
if !success || apiType == constant.APITypeAIProxyLibrary {
for i := 1; i <= common.ChannelTypeDummy; i++ {
apiType, success := relayconstant.ChannelType2APIType(i)
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
continue
}
meta := &relaycommon.RelayInfo{ChannelType: i}
@@ -97,17 +134,15 @@ func init() {
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
return m.Id
})
}
func ListModels(c *gin.Context) {
userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission()
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
@@ -115,22 +150,23 @@ func ListModels(c *gin.Context) {
tokenModelLimit = map[string]bool{}
}
for allowModel, _ := range tokenModelLimit {
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
userOpenAiModels = append(userOpenAiModels, oaiModel)
if _, ok := openAIModelsMap[allowModel]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
})
}
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId, false)
userGroup, err := model.GetUserGroup(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -139,34 +175,23 @@ func ListModels(c *gin.Context) {
return
}
group := userGroup
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
group = tokenGroup
}
var models []string
if tokenGroup == "auto" {
for _, autoGroup := range setting.AutoGroups {
groupModels := model.GetGroupEnabledModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
models = append(models, g)
}
}
}
} else {
models = model.GetGroupEnabledModels(group)
}
for _, modelName := range models {
if oaiModel, ok := openAIModelsMap[modelName]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
userOpenAiModels = append(userOpenAiModels, oaiModel)
models := model.GetGroupModels(group)
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
})
}
}

View File

@@ -6,8 +6,6 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"strings"
@@ -104,7 +102,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
err = ratio_setting.CheckGroupRatio(option.Value)
err = setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -121,42 +119,7 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "console_setting.api_info":
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
case "console_setting.announcements":
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
case "console_setting.faq":
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
case "console_setting.uptime_kuma_groups":
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
err = model.UpdateOption(option.Key, option.Value)
if err != nil {

View File

@@ -3,6 +3,7 @@ package controller
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -12,8 +13,6 @@ import (
"one-api/service"
"one-api/setting"
"time"
"github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -58,22 +57,13 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")
userCache, err := model.GetUserCache(userId)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
return
}
userCache.WriteContext(c)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
Relay(c)
}

View File

@@ -1,11 +1,10 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
"one-api/setting/operation_setting"
)
func GetPricing(c *gin.Context) {
@@ -13,7 +12,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
for s, f := range ratio_setting.GetGroupRatioCopy() {
for s, f := range setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
@@ -21,18 +20,12 @@ func GetPricing(c *gin.Context) {
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
for g := range groupRatio {
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
if ok {
groupRatio[g] = ratio
}
}
}
}
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range ratio_setting.GetGroupRatioCopy() {
for group := range setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
@@ -47,7 +40,7 @@ func GetPricing(c *gin.Context) {
}
func ResetModelRatio(c *gin.Context) {
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
defaultStr := operation_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
@@ -56,7 +49,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,

View File

@@ -1,24 +0,0 @@
package controller
import (
"net/http"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetRatioConfig(c *gin.Context) {
if !ratio_setting.IsExposeRatioEnabled() {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "倍率配置接口未启用",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ratio_setting.GetExposedData(),
})
}

View File

@@ -1,474 +0,0 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
const (
defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config"
maxConcurrentFetches = 8
)
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct {
Name string `json:"name"`
Data map[string]any `json:"data,omitempty"`
Err string `json:"err,omitempty"`
}
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
if req.Timeout <= 0 {
req.Timeout = defaultTimeoutSeconds
}
var upstreams []dto.UpstreamDTO
if len(req.Upstreams) > 0 {
for _, u := range req.Upstreams {
if strings.HasPrefix(u.BaseURL, "http") {
if u.Endpoint == "" {
u.Endpoint = defaultEndpoint
}
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
upstreams = append(upstreams, u)
}
}
} else if len(req.ChannelIDs) > 0 {
intIds := make([]int, 0, len(req.ChannelIDs))
for _, id64 := range req.ChannelIDs {
intIds = append(intIds, int(id64))
}
dbChannels, err := model.GetChannelsByIds(intIds)
if err != nil {
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
return
}
for _, ch := range dbChannels {
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
upstreams = append(upstreams, dto.UpstreamDTO{
ID: ch.Id,
Name: ch.Name,
BaseURL: strings.TrimRight(base, "/"),
Endpoint: "",
})
}
}
}
if len(upstreams) == 0 {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
return
}
var wg sync.WaitGroup
ch := make(chan upstreamResult, len(upstreams))
sem := make(chan struct{}, maxConcurrentFetches)
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
for _, chn := range upstreams {
wg.Add(1)
go func(chItem dto.UpstreamDTO) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
endpoint := chItem.Endpoint
if endpoint == "" {
endpoint = defaultEndpoint
} else if !strings.HasPrefix(endpoint, "/") {
endpoint = "/" + endpoint
}
fullURL := chItem.BaseURL + endpoint
uniqueName := chItem.Name
if chItem.ID != 0 {
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
}
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return
}
resp, err := client.Do(httpReq)
if err != nil {
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
return
}
// 兼容两种上游接口格式:
// type1: /api/ratio_config -> data 为 map[string]any包含 model_ratio/completion_ratio/cache_ratio/model_price
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
var body struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return
}
if !body.Success {
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
return
}
// 尝试按 type1 解析
var type1Data map[string]any
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
isType1 := false
for _, rt := range ratioTypes {
if _, ok := type1Data[rt]; ok {
isType1 = true
break
}
}
if isType1 {
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
return
}
}
// 如果不是 type1则尝试按 type2 (/api/pricing) 解析
var pricingItems []struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
CompletionRatio float64 `json:"completion_ratio"`
}
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
return
}
modelRatioMap := make(map[string]float64)
completionRatioMap := make(map[string]float64)
modelPriceMap := make(map[string]float64)
for _, item := range pricingItems {
if item.QuotaType == 1 {
modelPriceMap[item.ModelName] = item.ModelPrice
} else {
modelRatioMap[item.ModelName] = item.ModelRatio
// completionRatio 可能为 0此时也直接赋值保持与上游一致
completionRatioMap[item.ModelName] = item.CompletionRatio
}
}
converted := make(map[string]any)
if len(modelRatioMap) > 0 {
ratioAny := make(map[string]any, len(modelRatioMap))
for k, v := range modelRatioMap {
ratioAny[k] = v
}
converted["model_ratio"] = ratioAny
}
if len(completionRatioMap) > 0 {
compAny := make(map[string]any, len(completionRatioMap))
for k, v := range completionRatioMap {
compAny[k] = v
}
converted["completion_ratio"] = compAny
}
if len(modelPriceMap) > 0 {
priceAny := make(map[string]any, len(modelPriceMap))
for k, v := range modelPriceMap {
priceAny[k] = v
}
converted["model_price"] = priceAny
}
ch <- upstreamResult{Name: uniqueName, Data: converted}
}(chn)
}
wg.Wait()
close(ch)
localData := ratio_setting.GetExposedData()
var testResults []dto.TestResult
var successfulChannels []struct {
name string
data map[string]any
}
for r := range ch {
if r.Err != "" {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "error",
Error: r.Err,
})
} else {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "success",
})
successfulChannels = append(successfulChannels, struct {
name string
data map[string]any
}{name: r.Name, data: r.Data})
}
}
differences := buildDifferences(localData, successfulChannels)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"differences": differences,
"test_results": testResults,
},
})
}
func buildDifferences(localData map[string]any, successfulChannels []struct {
name string
data map[string]any
}) map[string]map[string]dto.DifferenceItem {
differences := make(map[string]map[string]dto.DifferenceItem)
allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio {
allModels[modelName] = struct{}{}
}
}
}
}
confidenceMap := make(map[string]map[string]bool)
// 预处理阶段检查pricing接口的可信度
for _, channel := range successfulChannels {
confidenceMap[channel.name] = make(map[string]bool)
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
if hasModelRatio && hasCompletionRatio {
// 遍历所有模型,检查是否满足不可信条件
for modelName := range allModels {
// 默认为可信
confidenceMap[channel.name][modelName] = true
// 检查是否满足不可信条件model_ratio为37.5且completion_ratio为1
if modelRatioVal, ok := modelRatios[modelName]; ok {
if completionRatioVal, ok := completionRatios[modelName]; ok {
// 转换为float64进行比较
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
confidenceMap[channel.name][modelName] = false
}
}
}
}
}
}
} else {
// 如果不是从pricing接口获取的数据则全部标记为可信
for modelName := range allModels {
confidenceMap[channel.name][modelName] = true
}
}
}
for modelName := range allModels {
for _, ratioType := range ratioTypes {
var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists {
localValue = val
}
}
}
upstreamValues := make(map[string]interface{})
confidenceValues := make(map[string]bool)
hasUpstreamValue := false
hasDifference := false
for _, channel := range successfulChannels {
var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val
hasUpstreamValue = true
if localValue != nil && localValue != val {
hasDifference = true
} else if localValue == val {
upstreamValue = "same"
}
}
}
if upstreamValue == nil && localValue == nil {
upstreamValue = "same"
}
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
hasDifference = true
}
upstreamValues[channel.name] = upstreamValue
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
}
shouldInclude := false
if localValue != nil {
if hasDifference {
shouldInclude = true
}
} else {
if hasUpstreamValue {
shouldInclude = true
}
}
if shouldInclude {
if differences[modelName] == nil {
differences[modelName] = make(map[string]dto.DifferenceItem)
}
differences[modelName][ratioType] = dto.DifferenceItem{
Current: localValue,
Upstreams: upstreamValues,
Confidence: confidenceValues,
}
}
}
}
channelHasDiff := make(map[string]bool)
for _, ratioMap := range differences {
for _, item := range ratioMap {
for chName, val := range item.Upstreams {
if val != nil && val != "same" {
channelHasDiff[chName] = true
}
}
}
}
for modelName, ratioMap := range differences {
for ratioType, item := range ratioMap {
for chName := range item.Upstreams {
if !channelHasDiff[chName] {
delete(item.Upstreams, chName)
delete(item.Confidence, chName)
}
}
allSame := true
for _, v := range item.Upstreams {
if v != "same" {
allSame = false
break
}
}
if len(item.Upstreams) == 0 || allSame {
delete(ratioMap, ratioType)
} else {
differences[modelName][ratioType] = item
}
}
if len(ratioMap) == 0 {
delete(differences, modelName)
}
}
return differences
}
func GetSyncableChannels(c *gin.Context) {
channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var syncableChannels []dto.SyncableChannel
for _, channel := range channels {
if channel.GetBaseURL() != "" {
syncableChannels = append(syncableChannels, dto.SyncableChannel{
ID: channel.Id,
Name: channel.Name,
BaseURL: channel.GetBaseURL(),
Status: channel.Status,
})
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": syncableChannels,
})
}

View File

@@ -5,7 +5,6 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"errors"
"github.com/gin-gonic/gin"
)
@@ -127,10 +126,6 @@ func AddRedemption(c *gin.Context) {
})
return
}
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
var keys []string
for i := 0; i < redemption.Count; i++ {
key := common.GetUUID()
@@ -140,7 +135,6 @@ func AddRedemption(c *gin.Context) {
Key: key,
CreatedTime: common.GetTimestamp(),
Quota: redemption.Quota,
ExpiredTime: redemption.ExpiredTime,
}
err = cleanRedemption.Insert()
if err != nil {
@@ -197,18 +191,12 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
if statusOnly == "" {
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
if statusOnly != "" {
cleanRedemption.Status = redemption.Status
} else {
// If you add more fields, please also update redemption.Update()
cleanRedemption.Name = redemption.Name
cleanRedemption.Quota = redemption.Quota
cleanRedemption.ExpiredTime = redemption.ExpiredTime
}
if statusOnly != "" {
cleanRedemption.Status = redemption.Status
}
err = cleanRedemption.Update()
if err != nil {
@@ -225,27 +213,3 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
func DeleteInvalidRedemption(c *gin.Context) {
rows, err := model.DeleteInvalidRedemptions()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": rows,
})
return
}
func validateExpiredTime(expired int64) error {
if expired != 0 && expired < common.GetTimestamp() {
return errors.New("过期时间不能早于当前时间")
}
return nil
}

View File

@@ -8,12 +8,12 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/constant"
constant2 "one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
}
func Relay(c *gin.Context) {
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
return
}
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == constant.ChannelTypeAnthropic {
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)

View File

@@ -75,14 +75,6 @@ func PostSetup(c *gin.Context) {
// If root doesn't exist, validate and create admin account
if !rootExists {
// Validate username length: max 12 characters to align with model.User validation
if len(req.Username) > 12 {
c.JSON(400, gin.H{
"success": false,
"message": "用户名长度不能超过12个字符",
})
return
}
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{

View File

@@ -74,8 +74,6 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
default:
common.SysLog("未知平台")
}
@@ -122,7 +120,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
}
if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
@@ -226,14 +224,9 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
func GetAllTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 1 {
p = 1
if p < 0 {
p = 0
}
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if pageSize <= 0 {
pageSize = common.ItemsPerPage
}
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数
@@ -244,32 +237,24 @@ func GetAllTask(c *gin.Context) {
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
ChannelID: c.Query("channel_id"),
}
items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
total := model.TaskCountAllTasks(queryParams)
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Task, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": items,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": logs,
})
}
func GetUserTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 1 {
p = 1
}
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if pageSize <= 0 {
pageSize = common.ItemsPerPage
if p < 0 {
p = 0
}
userId := c.GetInt("id")
@@ -286,17 +271,14 @@ func GetUserTask(c *gin.Context) {
EndTimestamp: endTimestamp,
}
items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
total := model.TaskCountAllUserTask(userId, queryParams)
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Task, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": items,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": logs,
})
}

View File

@@ -1,138 +0,0 @@
package controller
import (
"context"
"fmt"
"io"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
"time"
)
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(platform)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
"action": task.Action,
})
if err != nil {
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
}
//if resp.StatusCode != http.StatusOK {
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
//}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
taskResult, err := adaptor.ParseTaskResult(responseBody)
if err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
}
//if taskResult.Code != 0 {
// return fmt.Errorf("video task fetch failed for task %s", taskId)
//}
now := time.Now().Unix()
if taskResult.Status == "" {
return fmt.Errorf("task %s status is empty", taskId)
}
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
task.Progress = "10%"
case model.TaskStatusQueued:
task.Progress = "20%"
case model.TaskStatusInProgress:
task.Progress = "30%"
if task.StartTime == 0 {
task.StartTime = now
}
case model.TaskStatusSuccess:
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Url
case model.TaskStatusFailure:
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Reason
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
}
if taskResult.Progress != "" {
task.Progress = taskResult.Progress
}
task.Data = responseBody
if err := task.Update(); err != nil {
common.SysError("UpdateVideoTask task error: " + err.Error())
}
return nil
}

View File

@@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) {
userId := c.GetInt("id")
p, _ := strconv.Atoi(c.Query("p"))
size, _ := strconv.Atoi(c.Query("size"))
if p < 1 {
p = 1
if p < 0 {
p = 0
}
if size <= 0 {
size = common.ItemsPerPage
} else if size > 100 {
size = 100
}
tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
tokens, err := model.GetAllUserTokens(userId, p*size, size)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -28,18 +28,10 @@ func GetAllTokens(c *gin.Context) {
})
return
}
// Get total count for pagination
total, _ := model.CountUserTokens(userId)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": tokens,
"total": total,
"page": p,
"page_size": size,
},
"data": tokens,
})
return
}
@@ -258,32 +250,3 @@ func UpdateToken(c *gin.Context) {
})
return
}
type TokenBatch struct {
Ids []int `json:"ids"`
}
func DeleteTokenBatch(c *gin.Context) {
tokenBatch := TokenBatch{}
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
}

View File

@@ -97,14 +97,16 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
if !setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
return
payType := "wxpay"
if req.PaymentMethod == "zfb" {
payType = "alipay"
}
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = "wxpay"
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
@@ -114,7 +116,7 @@ func RequestEpay(c *gin.Context) {
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: req.PaymentMethod,
Type: payType,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),

View File

@@ -1,154 +0,0 @@
package controller
import (
"context"
"encoding/json"
"errors"
"net/http"
"one-api/setting/console_setting"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
const (
requestTimeout = 30 * time.Second
httpTimeout = 10 * time.Second
uptimeKeySuffix = "_24"
apiStatusPath = "/api/status-page/"
apiHeartbeatPath = "/api/status-page/heartbeat/"
)
type Monitor struct {
Name string `json:"name"`
Uptime float64 `json:"uptime"`
Status int `json:"status"`
Group string `json:"group,omitempty"`
}
type UptimeGroupResult struct {
CategoryName string `json:"categoryName"`
Monitors []Monitor `json:"monitors"`
}
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New("non-200 status")
}
return json.NewDecoder(resp.Body).Decode(dest)
}
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
url, _ := groupConfig["url"].(string)
slug, _ := groupConfig["slug"].(string)
categoryName, _ := groupConfig["categoryName"].(string)
result := UptimeGroupResult{
CategoryName: categoryName,
Monitors: []Monitor{},
}
if url == "" || slug == "" {
return result
}
baseURL := strings.TrimSuffix(url, "/")
var statusData struct {
PublicGroupList []struct {
ID int `json:"id"`
Name string `json:"name"`
MonitorList []struct {
ID int `json:"id"`
Name string `json:"name"`
} `json:"monitorList"`
} `json:"publicGroupList"`
}
var heartbeatData struct {
HeartbeatList map[string][]struct {
Status int `json:"status"`
} `json:"heartbeatList"`
UptimeList map[string]float64 `json:"uptimeList"`
}
g, gCtx := errgroup.WithContext(ctx)
g.Go(func() error {
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
})
g.Go(func() error {
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
})
if g.Wait() != nil {
return result
}
for _, pg := range statusData.PublicGroupList {
if len(pg.MonitorList) == 0 {
continue
}
for _, m := range pg.MonitorList {
monitor := Monitor{
Name: m.Name,
Group: pg.Name,
}
monitorID := strconv.Itoa(m.ID)
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
monitor.Uptime = uptime
}
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
monitor.Status = heartbeats[0].Status
}
result.Monitors = append(result.Monitors, monitor)
}
}
return result
}
func GetUptimeKumaStatus(c *gin.Context) {
groups := console_setting.GetUptimeKumaGroups()
if len(groups) == 0 {
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
defer cancel()
client := &http.Client{Timeout: httpTimeout}
results := make([]UptimeGroupResult, len(groups))
g, gCtx := errgroup.WithContext(ctx)
for i, group := range groups {
i, group := i, group
g.Go(func() error {
results[i] = fetchGroupData(gCtx, client, group)
return nil
})
}
g.Wait()
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
}

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting"
"strconv"
@@ -227,9 +226,6 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if setting.DefaultUseAutoGroup {
token.Group = "auto"
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -247,15 +243,15 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
pageInfo, err := common.GetPageQuery(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "parse page query failed",
})
return
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
users, total, err := model.GetAllUsers(pageInfo)
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -263,13 +259,15 @@ func GetAllUsers(c *gin.Context) {
})
return
}
pageInfo.SetTotal(int(total))
pageInfo.SetItems(users)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": pageInfo,
"data": gin.H{
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
@@ -461,9 +459,6 @@ func GetSelf(c *gin.Context) {
})
return
}
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
user.Remark = ""
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -488,7 +483,7 @@ func GetUserModels(c *gin.Context) {
groups := setting.GetUserUsableGroups(user.Group)
var models []string
for group := range groups {
for _, g := range model.GetGroupEnabledModels(group) {
for _, g := range model.GetGroupModels(group) {
if !common.StringsContains(models, g) {
models = append(models, g)
}
@@ -948,7 +943,6 @@ type UpdateUserSettingRequest struct {
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
RecordIpLog bool `json:"record_ip_log"`
}
func UpdateUserSetting(c *gin.Context) {
@@ -962,7 +956,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 验证预警类型
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
@@ -980,7 +974,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == dto.NotifyTypeWebhook {
if req.QuotaWarningType == constant.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -999,7 +993,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
@@ -1021,24 +1015,23 @@ func UpdateUserSetting(c *gin.Context) {
}
// 构建设置
settings := dto.UserSetting{
NotifyType: req.QuotaWarningType,
QuotaWarningThreshold: req.QuotaWarningThreshold,
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
RecordIpLog: req.RecordIpLog,
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
}
// 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == dto.NotifyTypeWebhook {
settings.WebhookUrl = req.WebhookUrl
if req.QuotaWarningType == constant.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
if req.WebhookSecret != "" {
settings.WebhookSecret = req.WebhookSecret
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
settings.NotificationEmail = req.NotificationEmail
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
}
// 更新用户设置

View File

@@ -11,17 +11,17 @@ services:
volumes:
- ./data:/data
- ./logs:/app/logs
- ${JS_SCRIPT_DIR:-./scripts}:/app/scripts
environment:
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间单位秒默认120秒如果出现空补全可以尝试改为更大值
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
depends_on:
- redis
- mysql

View File

@@ -1,7 +0,0 @@
package dto
type ChannelSettings struct {
ForceFormat bool `json:"force_format,omitempty"`
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
Proxy string `json:"proxy"`
}

View File

@@ -1,9 +1,6 @@
package dto
import (
"encoding/json"
"one-api/common"
)
import "encoding/json"
type ClaudeMetadata struct {
UserId string `json:"user_id"`
@@ -23,11 +20,11 @@ type ClaudeMediaMessage struct {
Delta string `json:"delta,omitempty"`
CacheControl json.RawMessage `json:"cache_control,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content any `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
}
func (c *ClaudeMediaMessage) SetText(s string) {
@@ -42,39 +39,15 @@ func (c *ClaudeMediaMessage) GetText() string {
}
func (c *ClaudeMediaMessage) IsStringContent() bool {
if c.Content == nil {
return false
}
_, ok := c.Content.(string)
if ok {
return true
}
return false
var content string
return json.Unmarshal(c.Content, &content) == nil
}
func (c *ClaudeMediaMessage) GetStringContent() string {
if c.Content == nil {
return ""
var content string
if err := json.Unmarshal(c.Content, &content); err == nil {
return content
}
switch c.Content.(type) {
case string:
return c.Content.(string)
case []any:
var contentStr string
for _, contentItem := range c.Content.([]any) {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
@@ -84,12 +57,16 @@ func (c *ClaudeMediaMessage) GetJsonRowString() string {
}
func (c *ClaudeMediaMessage) SetContent(content any) {
c.Content = content
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
return mediaContent
var mediaContent []ClaudeMediaMessage
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
return mediaContent
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeMessageSource struct {
@@ -105,36 +82,14 @@ type ClaudeMessage struct {
}
func (c *ClaudeMessage) IsStringContent() bool {
if c.Content == nil {
return false
}
_, ok := c.Content.(string)
return ok
}
func (c *ClaudeMessage) GetStringContent() string {
if c.Content == nil {
return ""
}
switch c.Content.(type) {
case string:
if c.IsStringContent() {
return c.Content.(string)
case []any:
var contentStr string
for _, contentItem := range c.Content.([]any) {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
@@ -143,7 +98,15 @@ func (c *ClaudeMessage) SetStringContent(content string) {
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
// map content to []ClaudeMediaMessage
// parse to json
jsonContent, _ := json.Marshal(c.Content)
var contentList []ClaudeMediaMessage
err := json.Unmarshal(jsonContent, &contentList)
if err != nil {
return make([]ClaudeMediaMessage, 0), err
}
return contentList, nil
}
type Tool struct {
@@ -178,14 +141,7 @@ type ClaudeRequest struct {
type Thinking struct {
Type string `json:"type"`
BudgetTokens *int `json:"budget_tokens,omitempty"`
}
func (c *Thinking) GetBudgetTokens() int {
if c.BudgetTokens == nil {
return 0
}
return *c.BudgetTokens
BudgetTokens int `json:"budget_tokens"`
}
func (c *ClaudeRequest) IsStringSystem() bool {
@@ -205,8 +161,14 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
}
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
return mediaContent
// map content to []ClaudeMediaMessage
// parse to json
jsonContent, _ := json.Marshal(c.System)
var contentList []ClaudeMediaMessage
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
return contentList
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeError struct {

View File

@@ -1,6 +1,9 @@
package dto
import "encoding/json"
import (
"encoding/json"
"reflect"
)
type ImageRequest struct {
Model string `json:"model"`
@@ -15,7 +18,58 @@ type ImageRequest struct {
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
// 用匿名字段接住额外的字段
Extra map[string]json.RawMessage `json:"-"`
}
func (r *ImageRequest) UnmarshalJSON(data []byte) error {
// 先解析成 map[string]interface{}
var rawMap map[string]json.RawMessage
if err := json.Unmarshal(data, &rawMap); err != nil {
return err
}
// 用 struct tag 获取所有已定义字段名
knownFields := GetJSONFieldNames(reflect.TypeOf(*r))
// 再正常解析已定义字段
type Alias ImageRequest
var known Alias
if err := json.Unmarshal(data, &known); err != nil {
return err
}
*r = ImageRequest(known)
// 提取多余字段
r.Extra = make(map[string]json.RawMessage)
for k, v := range rawMap {
if _, ok := knownFields[k]; !ok {
r.Extra[k] = v
}
}
return nil
}
func (r ImageRequest) MarshalJSON() ([]byte, error) {
// 将已定义字段转为 map
type Alias ImageRequest
alias := Alias(r)
base, err := json.Marshal(alias)
if err != nil {
return nil, err
}
var baseMap map[string]json.RawMessage
if err := json.Unmarshal(base, &baseMap); err != nil {
return nil, err
}
// 合并 ExtraFields
for k, v := range r.Extra {
baseMap[k] = v
}
return json.Marshal(baseMap)
}
type ImageResponse struct {
@@ -27,3 +81,37 @@ type ImageData struct {
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
fields := make(map[string]struct{})
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// 跳过匿名字段(例如 ExtraFields
if field.Anonymous {
continue
}
tag := field.Tag.Get("json")
if tag == "-" || tag == "" {
continue
}
// 取逗号前字段名(排除 omitempty 等)
name := tag
if commaIdx := indexComma(tag); commaIdx != -1 {
name = tag[:commaIdx]
}
fields[name] = struct{}{}
}
return fields
}
func indexComma(s string) int {
for i := 0; i < len(s); i++ {
if s[i] == ',' {
return i
}
}
return -1
}

View File

@@ -57,8 +57,6 @@ type MidjourneyDto struct {
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
VideoUrl string `json:"videoUrl"`
VideoUrls []ImgUrls `json:"videoUrls"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
@@ -67,10 +65,6 @@ type MidjourneyDto struct {
Properties *Properties `json:"properties"`
}
type ImgUrls struct {
Url string `json:"url"`
}
type MidjourneyStatus struct {
Status int `json:"status"`
}

View File

@@ -19,54 +19,50 @@ type FormatJsonSchema struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions json.RawMessage `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities json.RawMessage `json:"modalities,omitempty"`
Audio json.RawMessage `json:"audio,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` // ali
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
// OpenRouter Params
Usage json.RawMessage `json:"usage,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` // ali
ExtraBody any `json:"extra_body,omitempty"`
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
// OpenRouter Params
Reasoning json.RawMessage `json:"reasoning,omitempty"`
// Ali Qwen Params
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
}
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any)
data, _ := common.EncodeJson(r)
_ = common.UnmarshalJson(data, &result)
_ = common.DecodeJson(data, &result)
return result
}
@@ -111,16 +107,16 @@ func (r *GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
//parsedStringContent *string
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent
parsedStringContent *string
}
type MediaContent struct {
@@ -136,50 +132,21 @@ type MediaContent struct {
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
if m.ImageUrl != nil {
if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
return m.ImageUrl.(*MessageImageUrl)
}
if itemMap, ok := m.ImageUrl.(map[string]any); ok {
out := &MessageImageUrl{
Url: common.Interface2String(itemMap["url"]),
Detail: common.Interface2String(itemMap["detail"]),
MimeType: common.Interface2String(itemMap["mime_type"]),
}
return out
}
return m.ImageUrl.(*MessageImageUrl)
}
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
if _, ok := m.InputAudio.(*MessageInputAudio); ok {
return m.InputAudio.(*MessageInputAudio)
}
if itemMap, ok := m.InputAudio.(map[string]any); ok {
out := &MessageInputAudio{
Data: common.Interface2String(itemMap["data"]),
Format: common.Interface2String(itemMap["format"]),
}
return out
}
return m.InputAudio.(*MessageInputAudio)
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
if _, ok := m.File.(*MessageFile); ok {
return m.File.(*MessageFile)
}
if itemMap, ok := m.File.(map[string]any); ok {
out := &MessageFile{
FileName: common.Interface2String(itemMap["file_name"]),
FileData: common.Interface2String(itemMap["file_data"]),
FileId: common.Interface2String(itemMap["file_id"]),
}
return out
}
return m.File.(*MessageFile)
}
return nil
}
@@ -245,186 +212,6 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
func (m *Message) StringContent() string {
switch m.Content.(type) {
case string:
return m.Content.(string)
case []any:
var contentStr string
for _, contentItem := range m.Content.([]any) {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
func (m *Message) SetNullContent() {
m.Content = nil
m.parsedContent = nil
}
func (m *Message) SetStringContent(content string) {
m.Content = content
m.parsedContent = nil
}
func (m *Message) SetMediaContent(content []MediaContent) {
m.Content = content
m.parsedContent = content
}
func (m *Message) IsStringContent() bool {
_, ok := m.Content.(string)
if ok {
return true
}
return false
}
func (m *Message) ParseContent() []MediaContent {
if m.Content == nil {
return nil
}
if len(m.parsedContent) > 0 {
return m.parsedContent
}
var contentList []MediaContent
// 先尝试解析为字符串
content, ok := m.Content.(string)
if ok {
contentList = []MediaContent{{
Type: ContentTypeText,
Text: content,
}}
m.parsedContent = contentList
return contentList
}
// 尝试解析为数组
//var arrayContent []map[string]interface{}
arrayContent, ok := m.Content.([]any)
if !ok {
return contentList
}
for _, contentItemAny := range arrayContent {
mediaItem, ok := contentItemAny.(MediaContent)
if ok {
contentList = append(contentList, mediaItem)
continue
}
contentItem, ok := contentItemAny.(map[string]any)
if !ok {
continue
}
contentType, ok := contentItem["type"].(string)
if !ok {
continue
}
switch contentType {
case ContentTypeText:
if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: text,
})
}
case ContentTypeImageURL:
imageUrl := contentItem["image_url"]
temp := &MessageImageUrl{
Detail: "high",
}
switch v := imageUrl.(type) {
case string:
temp.Url = v
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if ok2 {
temp.Detail = detail
}
if ok1 {
temp.Url = url
}
}
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: temp,
})
case ContentTypeInputAudio:
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string)
if ok1 && ok2 {
temp := &MessageInputAudio{
Data: data,
Format: format,
}
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: temp,
})
}
}
case ContentTypeFile:
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
fileId, ok3 := fileData["file_id"].(string)
if ok3 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileId: fileId,
},
})
} else {
fileName, ok1 := fileData["filename"].(string)
fileDataStr, ok2 := fileData["file_data"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileName: fileName,
FileData: fileDataStr,
},
})
}
}
}
case ContentTypeVideoUrl:
if videoUrl, ok := contentItem["video_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeVideoUrl,
VideoUrl: &MessageVideoUrl{
Url: videoUrl,
},
})
}
}
}
if len(contentList) > 0 {
m.parsedContent = contentList
}
return contentList
}
// old code
/*func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
@@ -595,7 +382,7 @@ func (m *Message) ParseContent() []MediaContent {
m.parsedContent = contentList
}
return contentList
}*/
}
type WebSearchOptions struct {
SearchContextSize string `json:"search_context_size,omitempty"`
@@ -646,6 +433,4 @@ type ResponsesToolsCall struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Function json.RawMessage `json:"function,omitempty"`
Container json.RawMessage `json:"container,omitempty"`
}

View File

@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"`
Created any `json:"created"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Error *OpenAIError `json:"error,omitempty"`
Usage `json:"usage"`
@@ -178,8 +178,6 @@ type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
// OpenRouter Params
Cost float64 `json:"cost,omitempty"`
}
type InputTokenDetails struct {

View File

@@ -1,11 +1,26 @@
package dto
import "one-api/constant"
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
}

View File

@@ -1,38 +0,0 @@
package dto
type UpstreamDTO struct {
ID int `json:"id,omitempty"`
Name string `json:"name" binding:"required"`
BaseURL string `json:"base_url" binding:"required"`
Endpoint string `json:"endpoint"`
}
type UpstreamRequest struct {
ChannelIDs []int64 `json:"channel_ids"`
Upstreams []UpstreamDTO `json:"upstreams"`
Timeout int `json:"timeout"`
}
// TestResult 上游测试连通性结果
type TestResult struct {
Name string `json:"name"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
}
// DifferenceItem 差异项
// Current 为本地值,可能为 nil
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
type DifferenceItem struct {
Current interface{} `json:"current"`
Upstreams map[string]interface{} `json:"upstreams"`
Confidence map[string]bool `json:"confidence"`
}
type SyncableChannel struct {
ID int `json:"id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
Status int `json:"status"`
}

View File

@@ -4,7 +4,7 @@ type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n,omitempty"`
TopN int `json:"top_n"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`

View File

@@ -1,16 +0,0 @@
package dto
type UserSetting struct {
NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
}
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -1,47 +0,0 @@
package dto
type VideoRequest struct {
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
Width int `json:"width" example:"512"` // Video width
Height int `json:"height" example:"512"` // Video height
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
User string `json:"user,omitempty" example:"user-1234"` // User identifier
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
}
// VideoResponse 视频生成提交任务后的响应
type VideoResponse struct {
TaskId string `json:"task_id"`
Status string `json:"status"`
}
// VideoTaskResponse 查询视频生成任务状态的响应
type VideoTaskResponse struct {
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
Status string `json:"status" example:"succeeded"` // 任务状态
Url string `json:"url,omitempty"` // 视频资源URL成功时
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
}
// VideoTaskMetadata 视频任务元数据
type VideoTaskMetadata struct {
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
Fps int `json:"fps" example:"30"` // 实际帧率
Width int `json:"width" example:"512"` // 实际宽度
Height int `json:"height" example:"512"` // 实际高度
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
}
// VideoTaskError 视频任务错误信息
type VideoTaskError struct {
Code int `json:"code"`
Message string `json:"message"`
}

11
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994
github.com/bytedance/sonic v1.11.6
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -25,14 +25,13 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.7
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
github.com/tiktoken-go/tokenizer v0.6.2
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.35.0
golang.org/x/sync v0.11.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -44,13 +43,12 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
@@ -58,11 +56,9 @@ require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -88,6 +84,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect

18
go.sum
View File

@@ -1,7 +1,5 @@
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
@@ -40,10 +38,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 h1:aQYWswi+hRL2zJqGacdCZx32XjKYV8ApXFGntw79XAM=
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -87,8 +83,6 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
@@ -103,8 +97,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -173,6 +167,8 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
@@ -201,8 +197,6 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=

File diff suppressed because it is too large Load Diff

96
main.go
View File

@@ -12,7 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
"one-api/setting/ratio_setting"
"one-api/setting/operation_setting"
"os"
"strconv"
@@ -32,13 +32,14 @@ var buildFS embed.FS
var indexPage []byte
func main() {
err := InitResources()
err := godotenv.Load(".env")
if err != nil {
common.FatalLog("failed to initialize resources: " + err.Error())
return
common.SysLog("Support for .env file is disabled: " + err.Error())
}
common.LoadEnv()
common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
@@ -46,7 +47,19 @@ func main() {
if common.DebugEnabled {
common.SysLog("running in debug mode")
}
// Initialize SQL Database
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
model.CheckSetup()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
@@ -54,6 +67,21 @@ func main() {
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize model settings
operation_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
service.InitTokenEncoders()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
@@ -68,21 +96,19 @@ func main() {
if r := recover(); r != nil {
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
_, _, fixErr := model.FixAbility()
_, fixErr := model.FixAbility()
if fixErr != nil {
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
model.InitChannelCache()
}()
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
// 热更新配置
go model.SyncOptions(common.SyncFrequency)
// 数据看板
go model.UpdateQuotaData()
@@ -158,53 +184,3 @@ func main() {
common.FatalLog("failed to start HTTP server: " + err.Error())
}
}
func InitResources() error {
// Initialize resources here if needed
// This is a placeholder function for future resource initialization
err := godotenv.Load(".env")
if err != nil {
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
}
// 加载环境变量
common.InitEnv()
common.SetupLogger()
// Initialize model settings
ratio_setting.InitRatioSettings()
service.InitHttpClient()
service.InitTokenEncoders()
// Initialize SQL Database
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
return err
}
model.CheckSetup()
// Initialize options, should after model.InitDB()
model.InitOptionMap()
// 初始化模型
model.GetPricing()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
return err
}
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
return err
}
return nil
}

View File

@@ -184,7 +184,7 @@ func TokenAuth() func(c *gin.Context) {
}
}
// gemini api 从query中获取key
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
skKey := c.Query("key")
if skKey != "" {
c.Request.Header.Set("Authorization", "Bearer "+skKey)

View File

@@ -11,7 +11,6 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -25,7 +24,7 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
allowIpsMap := c.GetStringMap("allow_ips")
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
@@ -34,14 +33,14 @@ func Distribute() func(c *gin.Context) {
}
}
var channel *model.Channel
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
@@ -49,15 +48,13 @@ func Distribute() func(c *gin.Context) {
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
c.Set("group", userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -76,9 +73,9 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
// check token model mapping
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
@@ -98,14 +95,9 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
var selectGroup string
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
showGroup := userGroup
if userGroup == "auto" {
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
}
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -121,7 +113,7 @@ func Distribute() func(c *gin.Context) {
}
}
}
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
c.Set(constant.ContextKeyRequestStartTime, time.Now())
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
@@ -170,26 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
var platform string
var relayMode int
if strings.HasPrefix(modelRequest.Model, "jimeng") {
platform = string(constant.TaskPlatformJimeng)
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeJimengFetchByID {
shouldSelectChannel = false
}
} else {
platform = string(constant.TaskPlatformKling)
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeKlingFetchByID {
shouldSelectChannel = false
}
}
c.Set("platform", platform)
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
@@ -247,9 +220,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
}
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
c.Set("channel_type", channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
@@ -258,24 +231,24 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case constant.ChannelTypeAzure:
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case constant.ChannelTypeVertexAi:
case common.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case constant.ChannelTypeXunfei:
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case constant.ChannelTypeGemini:
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case constant.ChannelTypeAli:
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case constant.ChannelCloudflare:
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
case constant.ChannelTypeMokaAI:
case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
case constant.ChannelTypeCoze:
case common.ChannelTypeCoze:
c.Set("bot_id", channel.Other)
}
}

View File

@@ -1,62 +0,0 @@
package jsrt
import (
"os"
"strconv"
"time"
)
// Runtime 配置
type JSRuntimeConfig struct {
Enabled bool `json:"enabled"`
MaxVMCount int `json:"max_vm_count"`
ScriptTimeout time.Duration `json:"script_timeout"`
ScriptDir string `json:"script_dir"`
FetchTimeout time.Duration `json:"fetch_timeout"`
}
var (
jsConfig = JSRuntimeConfig{}
)
const (
defaultScriptDir = "scripts/"
defaultScriptTimeout = 5 * time.Second
defaultFetchTimeout = 10 * time.Second
defaultMaxVMCount = 8
)
func loadCfg() {
if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" {
jsConfig.Enabled = enabled == "true"
}
if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" {
if count, err := strconv.Atoi(maxCount); err == nil && count > 0 {
jsConfig.MaxVMCount = count
}
} else {
jsConfig.MaxVMCount = defaultMaxVMCount
}
if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" {
if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 {
jsConfig.ScriptTimeout = t
}
} else {
jsConfig.ScriptTimeout = defaultScriptTimeout
}
if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" {
if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 {
jsConfig.FetchTimeout = t
}
} else {
jsConfig.FetchTimeout = defaultFetchTimeout
}
jsConfig.ScriptDir = os.Getenv("JS_SCRIPT_DIR")
if jsConfig.ScriptDir == "" {
jsConfig.ScriptDir = defaultScriptDir
}
}

View File

@@ -1,69 +0,0 @@
package jsrt
import (
"one-api/common"
"gorm.io/gorm"
)
func dbQuery(db *gorm.DB, sql string, args ...any) []map[string]any {
if db == nil {
common.SysError("JS DB is nil")
return nil
}
rows, err := db.Raw(sql, args...).Rows()
if err != nil {
common.SysError("JS DB Query Error: " + err.Error())
return nil
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
common.SysError("JS DB Columns Error: " + err.Error())
return nil
}
results := make([]map[string]any, 0, 100)
for rows.Next() {
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
common.SysError("JS DB Scan Error: " + err.Error())
continue
}
row := make(map[string]any, len(columns))
for i, col := range columns {
val := values[i]
if b, ok := val.([]byte); ok {
row[col] = string(b)
} else {
row[col] = val
}
}
results = append(results, row)
}
return results
}
func dbExec(db *gorm.DB, sql string, args ...any) map[string]any {
if db == nil {
return map[string]any{
"rowsAffected": int64(0),
"error": "database is nil",
}
}
result := db.Exec(sql, args...)
return map[string]any{
"rowsAffected": result.RowsAffected,
"error": result.Error,
}
}

View File

@@ -1,137 +0,0 @@
package jsrt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type JSFetchRequest struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body any `json:"body"`
Timeout int `json:"timeout"`
}
type JSFetchResponse struct {
Status int `json:"status"`
Headers map[string]string `json:"headers"`
Body string `json:"body"`
Error string `json:"error,omitempty"`
}
func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse {
req := &JSFetchRequest{
Method: "GET",
URL: url,
Headers: make(map[string]string),
Timeout: int(jsConfig.FetchTimeout.Seconds()),
}
// 解析选项
if len(options) > 0 && options[0] != nil {
if optMap, ok := options[0].(map[string]any); ok {
if method, exists := optMap["method"]; exists {
if methodStr, ok := method.(string); ok {
req.Method = strings.ToUpper(methodStr)
}
}
if headers, exists := optMap["headers"]; exists {
if headersMap, ok := headers.(map[string]any); ok {
for k, v := range headersMap {
if vStr, ok := v.(string); ok {
req.Headers[k] = vStr
}
}
}
}
if body, exists := optMap["body"]; exists {
req.Body = body
}
if timeout, exists := optMap["timeout"]; exists {
if timeoutNum, ok := timeout.(float64); ok {
req.Timeout = int(timeoutNum)
}
}
}
}
// 创建HTTP请求
var bodyReader io.Reader
switch body := req.Body.(type) {
case string:
bodyReader = strings.NewReader(body)
case []byte:
bodyReader = bytes.NewReader(body)
case nil:
bodyReader = nil
default:
bodyBytes, err := json.Marshal(body)
if err != nil {
return &JSFetchResponse{
Error: fmt.Sprintf("Failed to marshal body: %v", err),
}
}
bodyReader = bytes.NewReader(bodyBytes)
}
httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader)
if err != nil {
return &JSFetchResponse{
Error: err.Error(),
}
}
// 设置请求头
for k, v := range req.Headers {
httpReq.Header.Set(k, v)
}
// 设置默认User-Agent
if httpReq.Header.Get("User-Agent") == "" {
httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0")
}
// 创建带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq = httpReq.WithContext(ctx)
// 执行请求
resp, err := p.httpClient.Do(httpReq)
if err != nil {
return &JSFetchResponse{}
}
defer resp.Body.Close()
// 读取响应体
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return &JSFetchResponse{
Status: resp.StatusCode,
}
}
// 构建响应头
headers := make(map[string]string)
for k, v := range resp.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
return &JSFetchResponse{
Status: resp.StatusCode,
Headers: headers,
Body: string(bodyBytes),
}
}

View File

@@ -1,570 +0,0 @@
package jsrt
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/dop251/goja"
"github.com/gin-gonic/gin"
)
// 池化
type JSRuntimePool struct {
pool chan *goja.Runtime
maxSize int
createFunc func() *goja.Runtime
scripts string
mu sync.RWMutex
httpClient *http.Client
}
var (
jsRuntimePool *JSRuntimePool
jsPoolOnce sync.Once
)
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
// 创建HTTP客户端
httpClient := &http.Client{
Timeout: jsConfig.FetchTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
},
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
},
}
pool := &JSRuntimePool{
pool: make(chan *goja.Runtime, maxSize),
maxSize: maxSize,
scripts: "",
httpClient: httpClient,
}
pool.createFunc = func() *goja.Runtime {
vm := goja.New()
pool.setupGlobals(vm)
pool.loadScripts(vm)
return vm
}
// 预创建VM
preCreate := min(maxSize/2, 4)
for range preCreate {
select {
case pool.pool <- pool.createFunc():
default:
}
}
return pool
}
func (p *JSRuntimePool) Get() *goja.Runtime {
select {
case vm := <-p.pool:
return vm
default:
return p.createFunc()
}
}
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
if vm == nil {
return
}
select {
case p.pool <- vm:
default:
// 池满丢弃VM让GC回收
}
}
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
// console
console := vm.NewObject()
console.Set("log", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysLog("JS: " + strings.Join(strs, " "))
})
console.Set("error", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS: " + strings.Join(strs, " "))
})
console.Set("warn", func(args ...any) {
var strs []string
for _, arg := range args {
strs = append(strs, fmt.Sprintf("%v", arg))
}
common.SysError("JS WARN: " + strings.Join(strs, " "))
})
vm.Set("console", console)
// JSON
jsonObj := vm.NewObject()
jsonObj.Set("parse", func(str string) any {
var result any
err := json.Unmarshal([]byte(str), &result)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return result
})
jsonObj.Set("stringify", func(obj any) string {
data, err := json.Marshal(obj)
if err != nil {
panic(vm.ToValue(err.Error()))
}
return string(data)
})
vm.Set("JSON", jsonObj)
// fetch 实现
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
return p.fetch(url, options...)
})
// 数据库
setDB(vm, model.DB, "db")
setDB(vm, model.LOG_DB, "logdb")
// 定时器
vm.Set("setTimeout", func(fn func(), delay int) {
go func() {
time.Sleep(time.Duration(delay) * time.Millisecond)
fn()
}()
})
}
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
p.mu.RLock()
defer p.mu.RUnlock()
// 如果已经缓存了合并的脚本,直接使用
if p.scripts != "" {
if _, err := vm.RunString(p.scripts); err != nil {
common.SysError("Failed to load cached scripts: " + err.Error())
}
return
}
// 首次加载时,读取 scripts/ 文件夹中的所有脚本
p.mu.RUnlock()
p.mu.Lock()
defer func() {
p.mu.Unlock()
p.mu.RLock()
}()
if p.scripts != "" {
if _, err := vm.RunString(p.scripts); err != nil {
common.SysError("Failed to load cached scripts: " + err.Error())
}
return
}
// 读取所有脚本文件
var combinedScript strings.Builder
scriptDir := jsConfig.ScriptDir
// 检查目录是否存在
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
common.SysLog("Scripts directory does not exist: " + scriptDir)
return
}
// 读取目录中的所有 .js 文件
files, err := filepath.Glob(filepath.Join(scriptDir, "*.js"))
if err != nil {
common.SysError("Failed to read scripts directory: " + err.Error())
return
}
if len(files) == 0 {
common.SysLog("No JavaScript files found in: " + scriptDir)
return
}
// 按文件名排序以确保加载顺序一致
for _, file := range files {
content, err := os.ReadFile(file)
if err != nil {
common.SysError("Failed to read script file " + file + ": " + err.Error())
continue
}
// 添加文件注释和内容
combinedScript.WriteString("// File: " + filepath.Base(file) + "\n")
combinedScript.WriteString(string(content))
combinedScript.WriteString("\n\n")
common.SysLog("Loaded script: " + filepath.Base(file))
}
// 缓存合并后的脚本
p.scripts = combinedScript.String()
// 执行脚本
if p.scripts != "" {
if _, err := vm.RunString(p.scripts); err != nil {
common.SysError("Failed to load combined scripts: " + err.Error())
} else {
common.SysLog("Successfully loaded and combined all JavaScript files from: " + scriptDir)
}
}
}
func (p *JSRuntimePool) ReloadScripts() {
p.mu.Lock()
defer p.mu.Unlock()
// 清空缓存的脚本
p.scripts = ""
// 清空VM池强制重新创建
for {
select {
case <-p.pool:
default:
goto done
}
}
done:
common.SysLog("JavaScript scripts reloaded")
}
func initJSRuntimePool() *JSRuntimePool {
jsPoolOnce.Do(func() {
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
common.SysLog("JavaScript runtime pool initialized successfully")
})
return jsRuntimePool
}
func validateGinContext(c *gin.Context) error {
if c == nil {
return fmt.Errorf("gin context is nil")
}
if c.Request == nil {
return fmt.Errorf("gin context request is nil")
}
return nil
}
func (p *JSRuntimePool) executeWithTimeout(_ *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
type result struct {
value goja.Value
err error
}
resultChan := make(chan result, 1)
go func() {
defer func() {
if r := recover(); r != nil {
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
}
}()
value, err := fn()
resultChan <- result{value: value, err: err}
}()
select {
case res := <-resultChan:
return res.value, res.err
case <-time.After(jsConfig.ScriptTimeout):
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
}
}
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
if err := validateGinContext(c); err != nil {
common.SysError("JS PreProcess Validation Error: " + err.Error())
return err
}
vm := p.Get()
defer p.Put(vm)
preProcessFunc := vm.Get("preProcessRequest")
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
return nil
}
jsReq, err := common.StructToMap(createJSReq(c))
if err != nil {
return fmt.Errorf("failed to create JS context: %v", err)
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
fn, ok := goja.AssertFunction(preProcessFunc)
if !ok {
return nil, fmt.Errorf("preProcessRequest is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsReq))
})
if err != nil {
common.SysError("JS PreProcess Error: " + err.Error())
return err
}
// 处理返回结果
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
// 是否修改请求
if newBody, exists := resultMap["body"]; exists {
switch v := newBody.(type) {
case string:
c.Request.Body = io.NopCloser(strings.NewReader(v))
c.Request.ContentLength = int64(len(v))
case []byte:
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
c.Request.ContentLength = int64(len(v))
case map[string]any:
bodyBytes, err := json.Marshal(v)
if err == nil {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
c.Request.ContentLength = int64(len(bodyBytes))
c.Request.Header.Set("Content-Type", "application/json")
} else {
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
}
default:
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
}
}
// 是否修改 headers
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Request.Header.Set(key, valueStr)
}
}
}
}
// 是否阻止请求
if block, exists := resultMap["block"]; exists {
if blockBool, ok := block.(bool); ok && blockBool {
status := http.StatusForbidden
if statusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := statusCode.(float64); ok {
status = int(statusInt)
}
}
message := "Request blocked by pre-process script"
if msg, exists := resultMap["message"]; exists {
if msgStr, ok := msg.(string); ok {
message = msgStr
}
}
c.JSON(status, gin.H{"error": message})
c.Abort()
return fmt.Errorf("request blocked")
}
}
}
}
return nil
}
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
if err := validateGinContext(c); err != nil {
common.SysError("JS PostProcess Validation Error: " + err.Error())
return statusCode, body, err
}
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
return statusCode, body, nil
}
jsReq, err := common.StructToMap(createJSReq(c))
if err != nil {
return statusCode, body, fmt.Errorf("failed to create JS context: %v", err)
}
jsResp := &JSResponse{
StatusCode: statusCode,
Headers: make(map[string]string),
Body: string(body),
}
// 获取响应头
if c.Writer != nil {
for key, values := range c.Writer.Header() {
if len(values) > 0 {
jsResp.Headers[key] = values[0]
}
}
}
jsResponse, err := common.StructToMap(jsResp)
if err != nil {
return statusCode, body, fmt.Errorf("failed to create JS response context: %v", err)
}
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
fn, ok := goja.AssertFunction(postProcessFunc)
if !ok {
return nil, fmt.Errorf("postProcessResponse is not a function")
}
return fn(goja.Undefined(), vm.ToValue(jsReq), vm.ToValue(jsResponse))
})
if err != nil {
common.SysError("JS PostProcess Error: " + err.Error())
return statusCode, body, err
}
// 处理返回
if result != nil && !goja.IsUndefined(result) {
resultObj := result.Export()
if resultMap, ok := resultObj.(map[string]any); ok {
if newStatusCode, exists := resultMap["statusCode"]; exists {
if statusInt, ok := newStatusCode.(float64); ok {
statusCode = int(statusInt)
}
}
if newBody, exists := resultMap["body"]; exists {
if bodyStr, ok := newBody.(string); ok {
body = []byte(bodyStr)
}
}
if newHeaders, exists := resultMap["headers"]; exists {
if headersMap, ok := newHeaders.(map[string]any); ok {
for key, value := range headersMap {
if valueStr, ok := value.(string); ok {
c.Header(key, valueStr)
}
}
}
}
}
}
return statusCode, body, nil
}
func (p *JSRuntimePool) hasPostProcessFunction() bool {
vm := p.Get()
defer p.Put(vm)
postProcessFunc := vm.Get("postProcessResponse")
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
}
func JSRuntimeMiddleware() *gin.HandlerFunc {
loadCfg()
if !jsConfig.Enabled {
common.SysLog("JavaScript Runtime is disabled")
return nil
}
pool := initJSRuntimePool()
var fn gin.HandlerFunc
fn = func(c *gin.Context) {
start := time.Now()
// 预处理
if err := pool.PreProcessRequest(c); err != nil {
common.SysError("JS Runtime PreProcess Error: " + err.Error())
return
}
duration := time.Since(start)
if duration > time.Millisecond*100 {
common.SysLog(fmt.Sprintf("JS Runtime PreProcess took %v", duration))
}
// 后处理
if pool.hasPostProcessFunction() {
writer := newResponseWriter(c.Writer)
c.Writer = writer
c.Next()
// 后处理响应
if writer.body.Len() > 0 {
start := time.Now()
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
if err == nil {
c.Writer = writer.ResponseWriter
for k, v := range writer.headerMap {
for _, value := range v {
c.Writer.Header().Add(k, value)
}
}
c.Status(statusCode)
if len(body) >= 0 {
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
c.Writer.Write(body)
} else {
c.Writer.Header().Del("Content-Length")
c.Writer.Write(body)
}
} else {
// 出错时回复原响应
c.Writer = writer.ResponseWriter
c.Status(writer.statusCode)
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
}
duration := time.Since(start)
if duration > time.Millisecond*100 {
common.SysLog(fmt.Sprintf("JS Runtime PostProcess took %v", duration))
}
} else {
// 没有响应体时恢复原始writer
c.Writer = writer.ResponseWriter
}
} else {
c.Next()
}
}
return &fn
}
func ReloadJSScripts() {
if jsRuntimePool != nil {
jsRuntimePool.ReloadScripts()
common.SysLog("JavaScript scripts reloaded")
}
}

View File

@@ -1,139 +0,0 @@
package jsrt
import (
"bytes"
"io"
"maps"
"net/http"
"sync"
"github.com/gin-gonic/gin"
)
// 请求
type JSReq struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string]string `json:"headers"`
Body any `json:"body"`
UserAgent string `json:"userAgent"`
RemoteIP string `json:"remoteIP"`
Extra map[string]any `json:"extra"`
}
type JSResponse struct {
StatusCode int `json:"statusCode"`
Headers map[string]string `json:"headers"`
Body string `json:"body"`
}
type responseWriter struct {
gin.ResponseWriter
body *bytes.Buffer
statusCode int
headerMap http.Header
written bool
mu sync.RWMutex
}
func createJSReq(c *gin.Context) *JSReq {
var bodyBytes []byte
if c.Request != nil && c.Request.Body != nil {
bodyBytes, _ = io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
// headers map
headers := make(map[string]string)
if c.Request != nil && c.Request.Header != nil {
for key, values := range c.Request.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
}
method := ""
url := ""
userAgent := ""
remoteIP := ""
contentType := ""
if c.Request != nil {
method = c.Request.Method
if c.Request.URL != nil {
url = c.Request.URL.String()
}
userAgent = c.Request.UserAgent()
contentType = c.ContentType()
}
if c != nil {
remoteIP = c.ClientIP()
}
parsedBody := parseBodyByType(bodyBytes, contentType)
return &JSReq{
Method: method,
URL: url,
Headers: headers,
Body: parsedBody,
UserAgent: userAgent,
RemoteIP: remoteIP,
Extra: make(map[string]any),
}
}
func newResponseWriter(w gin.ResponseWriter) *responseWriter {
return &responseWriter{
ResponseWriter: w,
body: &bytes.Buffer{},
statusCode: 200,
headerMap: make(http.Header),
written: false,
}
}
func (w *responseWriter) Write(data []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if !w.written {
w.WriteHeader(200)
}
return w.body.Write(data)
}
func (w *responseWriter) WriteString(s string) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
if !w.written {
w.WriteHeader(200)
}
return w.body.WriteString(s)
}
func (w *responseWriter) WriteHeader(statusCode int) {
w.mu.Lock()
defer w.mu.Unlock()
if w.written {
return
}
w.statusCode = statusCode
w.written = true
maps.Copy(w.headerMap, w.ResponseWriter.Header())
}
func (w *responseWriter) Header() http.Header {
w.mu.RLock()
defer w.mu.RUnlock()
if w.headerMap == nil {
w.headerMap = make(http.Header)
}
return w.headerMap
}

View File

@@ -1,86 +0,0 @@
package jsrt
import (
"encoding/json"
"net/url"
"one-api/common"
"strings"
"github.com/dop251/goja"
"gorm.io/gorm"
)
func setDB(vm *goja.Runtime, db *gorm.DB, name string) {
if db == nil {
common.SysError("JS DB is nil")
return
}
obj := vm.NewObject()
obj.Set("query", func(sql string, params ...any) []map[string]any {
return dbQuery(db, sql, params...)
})
obj.Set("exec", func(sql string, params ...any) map[string]any {
return dbExec(db, sql, params...)
})
if err := vm.Set(name, obj); err != nil {
common.SysError("Failed to set JS DB: " + err.Error())
return
}
}
func parseBodyByType(bodyBytes []byte, contentType string) any {
if len(bodyBytes) == 0 {
return ""
}
bodyStr := string(bodyBytes)
contentLower := strings.ToLower(contentType)
switch {
case strings.Contains(contentLower, "application/json"):
var jsonObj any
if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil {
return jsonObj
}
return bodyStr
case strings.Contains(contentLower, "application/x-www-form-urlencoded"):
if values, err := url.ParseQuery(bodyStr); err == nil {
result := make(map[string]string, len(values))
for k, v := range values {
if len(v) > 0 {
result[k] = v[0]
}
}
return result
}
return bodyStr
case strings.Contains(contentLower, "multipart/form-data"):
return bodyBytes
case strings.Contains(contentLower, "text/"):
return bodyStr
default:
// 尝试JSON解析
var jsonObj any
if json.Unmarshal(bodyBytes, &jsonObj) == nil {
return jsonObj
}
// 尝试form解析
if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 {
result := make(map[string]string, len(values))
for k, v := range values {
if len(v) > 0 {
result[k] = v[0]
}
}
return result
}
return bodyStr
}
}

View File

@@ -1,47 +0,0 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"one-api/common"
"one-api/constant"
"github.com/gin-gonic/gin"
)
func KlingRequestConvert() func(c *gin.Context) {
return func(c *gin.Context) {
var originalReq map[string]interface{}
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
c.Next()
return
}
model, _ := originalReq["model"].(string)
prompt, _ := originalReq["prompt"].(string)
unifiedReq := map[string]interface{}{
"model": model,
"prompt": prompt,
"metadata": originalReq,
}
jsonData, err := json.Marshal(unifiedReq)
if err != nil {
c.Next()
return
}
// Rewrite request body and path
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
c.Request.URL.Path = "/v1/video/generations"
if image := originalReq["image"]; image == "" {
c.Set("action", constant.TaskActionTextGenerate)
}
// We have to reset the request body for the next handlers
c.Set(common.KeyRequestBody, jsonData)
c.Next()
}
}

View File

@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 获取分组
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
group := c.GetString("token_group")
if group == "" {
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
group = c.GetString(constant.ContextKeyUserGroup)
}
//获取分组的限流配置

View File

@@ -1,41 +0,0 @@
package middleware
import (
"sync/atomic"
"github.com/gin-gonic/gin"
)
// HTTPStats 存储HTTP统计信息
type HTTPStats struct {
activeConnections int64
}
var globalStats = &HTTPStats{}
// StatsMiddleware 统计中间件
func StatsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// 增加活跃连接数
atomic.AddInt64(&globalStats.activeConnections, 1)
// 确保在请求结束时减少连接数
defer func() {
atomic.AddInt64(&globalStats.activeConnections, -1)
}()
c.Next()
}
}
// StatsInfo 统计信息结构
type StatsInfo struct {
ActiveConnections int64 `json:"active_connections"`
}
// GetStats 获取统计信息
func GetStats() StatsInfo {
return StatsInfo{
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
}
}

View File

@@ -5,11 +5,9 @@ import (
"fmt"
"one-api/common"
"strings"
"sync"
"github.com/samber/lo"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type Ability struct {
@@ -22,25 +20,10 @@ type Ability struct {
Tag *string `json:"tag" gorm:"index"`
}
type AbilityWithChannel struct {
Ability
ChannelType int `json:"channel_type"`
}
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
var abilities []AbilityWithChannel
err := DB.Table("abilities").
Select("abilities.*, channels.type as channel_type").
Joins("left join channels on abilities.channel_id = channels.id").
Where("abilities.enabled = ?", true).
Scan(&abilities).Error
return abilities, err
}
func GetGroupEnabledModels(group string) []string {
func GetGroupModels(group string) []string {
var models []string
// Find distinct models
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models
}
@@ -58,12 +41,16 @@ func GetAllEnableAbilities() []Ability {
}
func getPriority(group string, model string, retry int) (int, error) {
trueVal := "1"
if common.UsingPostgreSQL {
trueVal = "true"
}
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
Order("priority DESC"). // 按优先级降序排序
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
@@ -88,14 +75,18 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
trueVal := "1"
if common.UsingPostgreSQL {
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
}
}
@@ -142,15 +133,9 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
key := group + "|" + model
if _, exists := abilitySet[key]; exists {
continue
}
abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -167,7 +152,7 @@ func (channel *Channel) AddAbilities() error {
return nil
}
for _, chunk := range lo.Chunk(abilities, 50) {
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
err := DB.Create(&chunk).Error
if err != nil {
return err
}
@@ -209,15 +194,9 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
// Then add new abilities
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
key := group + "|" + model
if _, exists := abilitySet[key]; exists {
continue
}
abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -233,7 +212,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
err = tx.Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
@@ -273,45 +252,74 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}
var fixLock = sync.Mutex{}
func FixAbility() (int, int, error) {
lock := fixLock.TryLock()
if !lock {
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
}
defer fixLock.Unlock()
var channels []*Channel
// Find all channels
err := DB.Model(&Channel{}).Find(&channels).Error
func FixAbility() (int, error) {
var channelIds []int
count := 0
// Find all channel ids from channel table
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
if err != nil {
return 0, 0, err
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
}
if len(channels) == 0 {
return 0, 0, nil
}
successCount := 0
failCount := 0
for _, chunk := range lo.Chunk(channels, 50) {
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
// Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk)
continue
}
// Then add new abilities
for _, channel := range chunk {
err = channel.AddAbilities()
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
if len(channelIds) > 0 {
// Process deletion in chunks to avoid "too many placeholders" error
for _, chunk := range lo.Chunk(channelIds, 100) {
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++
} else {
successCount++
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
return 0, err
}
}
} else {
// If no channels exist, delete all abilities
err = DB.Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
return 0, err
}
common.SysLog("Delete all abilities successfully")
return 0, nil
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return count, err
}
var channels []Channel
if len(abilityChannelIds) == 0 {
err = DB.Find(&channels).Error
} else {
// Process query in chunks to avoid "too many placeholders" error
err = nil
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
var channelsChunk []Channel
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
if err != nil {
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
return count, err
}
channels = append(channels, channelsChunk...)
}
}
for _, channel := range channels {
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
count++
}
}
InitChannelCache()
return successCount, failCount, nil
return count, nil
}

View File

@@ -5,13 +5,10 @@ import (
"fmt"
"math/rand"
"one-api/common"
"one-api/setting"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -78,43 +75,7 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
var channel *Channel
var err error
selectGroup := group
if group == "auto" {
if len(setting.AutoGroups) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
for _, autoGroup := range setting.AutoGroups {
if common.DebugEnabled {
println("autoGroup:", autoGroup)
}
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
if channel == nil {
continue
} else {
c.Set("auto_group", autoGroup)
selectGroup = autoGroup
if common.DebugEnabled {
println("selectGroup:", selectGroup)
}
break
}
}
} else {
channel, err = getRandomSatisfiedChannel(group, model, retry)
if err != nil {
return nil, group, err
}
}
if channel == nil {
return nil, group, errors.New("channel not found")
}
return channel, selectGroup, nil
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}

View File

@@ -3,7 +3,6 @@ package model
import (
"encoding/json"
"one-api/common"
"one-api/dto"
"strings"
"sync"
@@ -146,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit("key")
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
// 构造WHERE子句
var whereClause string
@@ -154,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -479,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
}
// 构造基础查询
baseQuery := DB.Model(&Channel{}).Omit("key")
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
// 构造WHERE子句
var whereClause string
@@ -487,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
}
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -515,19 +514,8 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
return tags, nil
}
func (channel *Channel) ValidateSettings() error {
channelParams := &dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
if err != nil {
return err
}
}
return nil
}
func (channel *Channel) GetSetting() dto.ChannelSettings {
setting := dto.ChannelSettings{}
func (channel *Channel) GetSetting() map[string]interface{} {
setting := make(map[string]interface{})
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
@@ -537,7 +525,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
return setting
}
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
func (channel *Channel) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
@@ -595,53 +583,3 @@ func BatchSetChannelTag(ids []int, tag *string) error {
// 提交事务
return tx.Commit().Error
}
// CountAllChannels returns total channels in DB
func CountAllChannels() (int64, error) {
var total int64
err := DB.Model(&Channel{}).Count(&total).Error
return total, err
}
// CountAllTags returns number of non-empty distinct tags
func CountAllTags() (int64, error) {
var total int64
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
return total, err
}
// Get channels of specified type with pagination
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
var channels []*Channel
order := "priority desc"
if idSort {
order = "id desc"
}
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
return channels, err
}
// Count channels of specific type
func CountChannelsByType(channelType int) (int64, error) {
var count int64
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
return count, err
}
// Return map[type]count for all channels
func CountChannelsGroupByType() (map[int64]int64, error) {
type result struct {
Type int64 `gorm:"column:type"`
Count int64 `gorm:"column:count"`
}
var results []result
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
if err != nil {
return nil, err
}
counts := make(map[int64]int64)
for _, r := range results {
counts[r.Type] = r.Count
}
return counts, nil
}

View File

@@ -32,7 +32,6 @@ type Log struct {
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Ip string `json:"ip" gorm:"index;default:''"`
Other string `json:"other"`
}
@@ -62,7 +61,7 @@ func formatUserLogs(logs []*Log) {
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -96,13 +95,6 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if settingMap.RecordIpLog {
needRecordIp = true
}
}
log := &Log{
UserId: userId,
Username: username,
@@ -119,13 +111,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Ip: func() string {
if needRecordIp {
return c.ClientIP()
}
return ""
}(),
Other: otherStr,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -133,59 +119,32 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
}
type RecordConsumeLogParams struct {
ChannelId int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
TokenName string `json:"token_name"`
Quota int `json:"quota"`
Content string `json:"content"`
TokenId int `json:"token_id"`
UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"`
Group string `json:"group"`
Other map[string]interface{} `json:"other"`
}
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
}
username := c.GetString("username")
otherStr := common.MapToJsonStr(params.Other)
// 判断是否需要记录 IP
needRecordIp := false
if settingMap, err := GetUserSetting(userId, false); err == nil {
if settingMap.RecordIpLog {
needRecordIp = true
}
}
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
Content: params.Content,
PromptTokens: params.PromptTokens,
CompletionTokens: params.CompletionTokens,
TokenName: params.TokenName,
ModelName: params.ModelName,
Quota: params.Quota,
ChannelId: params.ChannelId,
TokenId: params.TokenId,
UseTime: params.UseTimeSeconds,
IsStream: params.IsStream,
Group: params.Group,
Ip: func() string {
if needRecordIp {
return c.ClientIP()
}
return ""
}(),
Other: otherStr,
Content: content,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
ChannelId: channelId,
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -193,7 +152,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
if common.DataExportEnabled {
gopool.Go(func() {
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
})
}
}
@@ -225,7 +184,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("logs.channel_id = ?", channel)
}
if group != "" {
tx = tx.Where("logs."+logGroupCol+" = ?", group)
tx = tx.Where("logs."+groupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -236,18 +195,13 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return nil, 0, err
}
channelIdsMap := make(map[int]struct{})
channelIds := make([]int, 0)
channelMap := make(map[int]string)
for _, log := range logs {
if log.ChannelId != 0 {
channelIdsMap[log.ChannelId] = struct{}{}
channelIds = append(channelIds, log.ChannelId)
}
}
channelIds := make([]int, 0, len(channelIdsMap))
for channelId := range channelIdsMap {
channelIds = append(channelIds, channelId)
}
if len(channelIds) > 0 {
var channels []struct {
Id int `gorm:"column:id"`
@@ -288,7 +242,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = tx.Where("logs.created_at <= ?", endTimestamp)
}
if group != "" {
tx = tx.Where("logs."+logGroupCol+" = ?", group)
tx = tx.Where("logs."+groupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -349,8 +303,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
if group != "" {
tx = tx.Where(logGroupCol+" = ?", group)
rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
tx = tx.Where(groupCol+" = ?", group)
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
}
tx = tx.Where("type = ?", LogTypeConsume)

View File

@@ -1,7 +1,6 @@
package model
import (
"fmt"
"log"
"one-api/common"
"one-api/constant"
@@ -16,48 +15,18 @@ import (
"gorm.io/gorm"
)
var commonGroupCol string
var commonKeyCol string
var commonTrueVal string
var commonFalseVal string
var logKeyCol string
var logGroupCol string
var groupCol string
var keyCol string
func initCol() {
// init common column names
if common.UsingPostgreSQL {
commonGroupCol = `"group"`
commonKeyCol = `"key"`
commonTrueVal = "true"
commonFalseVal = "false"
groupCol = `"group"`
keyCol = `"key"`
} else {
commonGroupCol = "`group`"
commonKeyCol = "`key`"
commonTrueVal = "1"
commonFalseVal = "0"
groupCol = "`group`"
keyCol = "`key`"
}
if os.Getenv("LOG_SQL_DSN") != "" {
switch common.LogSqlType {
case common.DatabaseTypePostgreSQL:
logGroupCol = `"group"`
logKeyCol = `"key"`
default:
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
} else {
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
if common.UsingPostgreSQL {
logGroupCol = `"group"`
logKeyCol = `"key"`
} else {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
}
// log sql type and database type
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
}
var DB *gorm.DB
@@ -114,7 +83,7 @@ func CheckSetup() {
}
}
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
func chooseDB(envName string) (*gorm.DB, error) {
defer func() {
initCol()
}()
@@ -123,11 +92,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
if !isLog {
common.UsingPostgreSQL = true
} else {
common.LogSqlType = common.DatabaseTypePostgreSQL
}
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -137,11 +102,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
if !isLog {
common.UsingSQLite = true
} else {
common.LogSqlType = common.DatabaseTypeSQLite
}
common.UsingSQLite = true
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -156,11 +117,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
if !isLog {
common.UsingMySQL = true
} else {
common.LogSqlType = common.DatabaseTypeMySQL
}
common.UsingMySQL = true
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -174,7 +131,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
}
func InitDB() (err error) {
db, err := chooseDB("SQL_DSN", false)
db, err := chooseDB("SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -192,7 +149,7 @@ func InitDB() (err error) {
return nil
}
if common.UsingMySQL {
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = migrateDB()
@@ -208,7 +165,7 @@ func InitLogDB() (err error) {
LOG_DB = DB
return
}
db, err := chooseDB("LOG_SQL_DSN", true)
db, err := chooseDB("LOG_SQL_DSN")
if err == nil {
if common.DebugEnabled {
db = db.Debug()
@@ -241,73 +198,54 @@ func InitLogDB() (err error) {
}
func migrateDB() error {
if !common.UsingPostgreSQL {
return migrateDBFast()
}
err := DB.AutoMigrate(
&Channel{},
&Token{},
&User{},
&Option{},
&Redemption{},
&Ability{},
&Log{},
&Midjourney{},
&TopUp{},
&QuotaData{},
&Task{},
&Setup{},
)
err := DB.AutoMigrate(&Channel{})
if err != nil {
return err
}
return nil
}
func migrateDBFast() error {
var wg sync.WaitGroup
errChan := make(chan error, 12) // Buffer size matches number of migrations
migrations := []struct {
model interface{}
name string
}{
{&Channel{}, "Channel"},
{&Token{}, "Token"},
{&User{}, "User"},
{&Option{}, "Option"},
{&Redemption{}, "Redemption"},
{&Ability{}, "Ability"},
{&Log{}, "Log"},
{&Midjourney{}, "Midjourney"},
{&TopUp{}, "TopUp"},
{&QuotaData{}, "QuotaData"},
{&Task{}, "Task"},
{&Setup{}, "Setup"},
err = DB.AutoMigrate(&Token{})
if err != nil {
return err
}
for _, m := range migrations {
wg.Add(1)
go func(model interface{}, name string) {
defer wg.Done()
if err := DB.AutoMigrate(model); err != nil {
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
}
}(m.model, m.name)
err = DB.AutoMigrate(&User{})
if err != nil {
return err
}
// Wait for all migrations to complete
wg.Wait()
close(errChan)
// Check for any errors
for err := range errChan {
if err != nil {
return err
}
err = DB.AutoMigrate(&Option{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Redemption{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Ability{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Log{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Midjourney{})
if err != nil {
return err
}
err = DB.AutoMigrate(&TopUp{})
if err != nil {
return err
}
err = DB.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Task{})
if err != nil {
return err
}
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
return nil
//err = createRootAccountIfNeed()
return err
}
func migrateLOGDB() error {

View File

@@ -14,8 +14,6 @@ type Midjourney struct {
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
VideoUrl string `json:"video_url"`
VideoUrls string `json:"video_urls"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`
@@ -168,40 +166,3 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
Where("id in (?)", taskIDs).
Updates(params).Error
}
// CountAllTasks returns total midjourney tasks for admin query
func CountAllTasks(queryParams TaskQueryParams) int64 {
var total int64
query := DB.Model(&Midjourney{})
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.MjID != "" {
query = query.Where("mj_id = ?", queryParams.MjID)
}
if queryParams.StartTimestamp != "" {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != "" {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
_ = query.Count(&total).Error
return total
}
// CountAllUserTask returns total midjourney tasks for user
func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
var total int64
query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
if queryParams.MjID != "" {
query = query.Where("mj_id = ?", queryParams.MjID)
}
if queryParams.StartTimestamp != "" {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != "" {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
_ = query.Count(&total).Error
return total
}

View File

@@ -5,7 +5,6 @@ import (
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -77,9 +76,6 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -98,13 +94,12 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -127,7 +122,6 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -197,7 +191,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -266,10 +260,6 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled":
ratio_setting.SetExposeRatioEnabled(boolValue)
}
}
switch key {
@@ -296,8 +286,6 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
case "AutoGroups":
err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -363,19 +351,17 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = ratio_setting.UpdateModelRatioByJSONString(value)
err = operation_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = ratio_setting.UpdateGroupRatioByJSONString(value)
case "GroupGroupRatio":
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
err = operation_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = ratio_setting.UpdateModelPriceByJSONString(value)
err = operation_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = ratio_setting.UpdateCacheRatioByJSONString(value)
err = operation_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -392,8 +378,6 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods":
err = setting.UpdatePayMethodsByJsonString(value)
}
return err
}

View File

@@ -1,24 +1,20 @@
package model
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/setting/ratio_setting"
"one-api/types"
"one-api/setting/operation_setting"
"sync"
"time"
)
type Pricing struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups"`
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups,omitempty"`
}
var (
@@ -27,98 +23,56 @@ var (
updatePricingLock sync.Mutex
)
var (
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
modelSupportEndpointsLock = sync.RWMutex{}
)
func GetPricing() []Pricing {
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
// Double check after acquiring the lock
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
modelSupportEndpointsLock.Lock()
defer modelSupportEndpointsLock.Unlock()
updatePricing()
}
}
return pricingMap
}
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
if model == "" {
return make([]constant.EndpointType, 0)
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
}
modelSupportEndpointsLock.RLock()
defer modelSupportEndpointsLock.RUnlock()
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
return endpoints
}
return make([]constant.EndpointType, 0)
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil {
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return
}
modelGroupsMap := make(map[string]*types.Set[string])
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
for _, ability := range enableAbilities {
groups, ok := modelGroupsMap[ability.Model]
if !ok {
groups = types.NewSet[string]()
modelGroupsMap[ability.Model] = groups
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
}
groups.Add(ability.Group)
}
//这里使用切片而不是Set因为一个模型可能支持多个端点类型并且第一个端点是优先使用端点
modelSupportEndpointsStr := make(map[string][]string)
for _, ability := range enableAbilities {
endpoints, ok := modelSupportEndpointsStr[ability.Model]
if !ok {
endpoints = make([]string, 0)
modelSupportEndpointsStr[ability.Model] = endpoints
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
}
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
for _, channelType := range channelTypes {
if !common.StringsContains(endpoints, string(channelType)) {
endpoints = append(endpoints, string(channelType))
}
}
modelSupportEndpointsStr[ability.Model] = endpoints
}
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
for model, endpoints := range modelSupportEndpointsStr {
supportedEndpoints := make([]constant.EndpointType, 0)
for _, endpointStr := range endpoints {
endpointType := constant.EndpointType(endpointStr)
supportedEndpoints = append(supportedEndpoints, endpointType)
}
modelSupportEndpointTypes[model] = supportedEndpoints
modelGroupsMap[ability.Model] = groups
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
ModelName: model,
EnableGroup: groups.Items(),
SupportedEndpointTypes: modelSupportEndpointTypes[model],
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
modelRatio, _ := ratio_setting.GetModelRatio(model)
modelRatio, _ := operation_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)

View File

@@ -21,7 +21,6 @@ type Redemption struct {
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间0 表示不过期
}
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
@@ -132,9 +131,6 @@ func Redeem(key string, userId int) (quota int, err error) {
if redemption.Status != common.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
return errors.New("该兑换码已过期")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
if err != nil {
return err
@@ -166,7 +162,7 @@ func (redemption *Redemption) SelectUpdate() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (redemption *Redemption) Update() error {
var err error
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
return err
}
@@ -187,9 +183,3 @@ func DeleteRedemptionById(id int) (err error) {
}
return redemption.Delete()
}
func DeleteInvalidRedemptions() (int64, error) {
now := common.GetTimestamp()
result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
return result.RowsAffected, result.Error
}

View File

@@ -302,64 +302,3 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}
// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
var total int64
query := DB.Model(&Task{})
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
_ = query.Count(&total).Error
return total
}
// TaskCountAllUserTask returns total tasks for given user
func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
var total int64
query := DB.Model(&Task{}).Where("user_id = ?", userId)
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
_ = query.Count(&total).Error
return total
}

View File

@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" {
token = strings.Trim(token, "sk-")
}
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
err = DB.Where(keyCol+" = ?", key).First(&token).Error
return token, err
}
@@ -320,44 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
// CountUserTokens returns total number of tokens for the given user, used for pagination
func CountUserTokens(userId int) (int64, error) {
var total int64
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
return total, err
}
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
func BatchDeleteTokens(ids []int, userId int) (int, error) {
if len(ids) == 0 {
return 0, errors.New("ids 不能为空!")
}
tx := DB.Begin()
var tokens []Token
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
tx.Rollback()
return 0, err
}
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
tx.Rollback()
return 0, err
}
if err := tx.Commit().Error; err != nil {
return 0, err
}
if common.RedisEnabled {
gopool.Go(func() {
for _, t := range tokens {
_ = cacheDeleteToken(t.Key)
}
})
}
return len(tokens), nil
}

View File

@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
if err != nil {
return err
}

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
@@ -42,7 +41,6 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
}
func (user *User) ToBaseUser() *UserBase {
@@ -69,18 +67,14 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
func (user *User) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
func (user *User) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return setting
return common.StrToMap(user.Setting)
}
func (user *User) SetSetting(setting dto.UserSetting) {
func (user *User) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
@@ -119,7 +113,7 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
// Start transaction
tx := DB.Begin()
if tx.Error != nil {
@@ -139,7 +133,7 @@ func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err err
}
// Get paginated users within same transaction
err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
@@ -181,7 +175,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
// 如果是数字同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
if group != "" {
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -190,7 +184,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
} else {
// 非数字关键字,只搜索字符串字段
if group != "" {
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -372,7 +366,6 @@ func (user *User) Edit(updatePassword bool) error {
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
"remark": newUser.Remark,
}
if updatePassword {
updates["password"] = newUser.Password
@@ -622,7 +615,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
if err != nil {
return "", err
}
@@ -631,7 +624,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
@@ -653,12 +646,10 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
return settingMap, err
return map[string]interface{}{}, err
}
userBase := &UserBase{
Setting: setting,
}
return userBase.GetSetting(), nil
return common.StrToMap(setting), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"time"
"github.com/gin-gonic/gin"
@@ -25,23 +24,28 @@ type UserBase struct {
}
func (user *UserBase) WriteContext(c *gin.Context) {
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() dto.UserSetting {
setting := dto.UserSetting{}
if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
}
func (user *UserBase) GetSetting() map[string]interface{} {
if user.Setting == "" {
return nil
}
return setting
return common.StrToMap(user.Setting)
}
func (user *UserBase) SetSetting(setting map[string]interface{}) {
settingBytes, err := json.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
}
// getUserCacheKey returns the key for user cache
@@ -66,7 +70,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
@@ -170,10 +174,11 @@ func getUserNameCache(userId int) (string, error) {
return cache.Username, nil
}
func getUserSettingCache(userId int) (dto.UserSetting, error) {
func getUserSettingCache(userId int) (map[string]interface{}, error) {
setting := make(map[string]interface{})
cache, err := GetUserCache(userId)
if err != nil {
return dto.UserSetting{}, err
return setting, err
}
return cache.GetSetting(), nil
}

View File

@@ -2,12 +2,11 @@ package model
import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
const (
@@ -49,22 +48,6 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
// check if there's any data to update
hasData := false
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
if len(batchUpdateStores[i]) > 0 {
hasData = true
batchUpdateLocks[i].Unlock()
break
}
batchUpdateLocks[i].Unlock()
}
if !hasData {
return
}
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()

View File

@@ -44,6 +44,4 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}

View File

@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
@@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
return embeddingRequestOpenAI2Ali(request), nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

View File

@@ -132,7 +132,10 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -36,7 +35,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)

Some files were not shown because too many files have changed in this diff Show More