mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 07:37:23 +00:00
Merge branch 'alpha' into base
This commit is contained in:
@@ -135,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := service.GetHttpClient().Do(req)
|
||||
client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVidu {
|
||||
return testResult{
|
||||
localErr: errors.New("vidu channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -126,10 +132,27 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
newAPIError: newAPIError,
|
||||
}
|
||||
}
|
||||
request := buildTestRequest(testModel)
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
// Determine relay format based on request path
|
||||
relayFormat := types.RelayFormatOpenAI
|
||||
if c.Request.URL.Path == "/v1/embeddings" {
|
||||
relayFormat = types.RelayFormatEmbedding
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
||||
}
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
@@ -137,7 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||
}
|
||||
}
|
||||
|
||||
testModel = info.UpstreamModelName
|
||||
request.Model = testModel
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
@@ -149,13 +174,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||
logInfo := *info
|
||||
logInfo.ApiKey = ""
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
||||
//// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||
//logInfo := info
|
||||
//logInfo.ApiKey = ""
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
@@ -203,7 +227,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
@@ -214,7 +238,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -230,7 +254,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
||||
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
@@ -240,7 +264,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
||||
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
@@ -269,7 +293,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
Quota: quota,
|
||||
Content: "模型测试",
|
||||
UseTimeSeconds: int(consumedTime),
|
||||
IsStream: false,
|
||||
IsStream: info.IsStream,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
@@ -326,8 +350,11 @@ func TestChannel(c *gin.Context) {
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
channel, err = model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
@@ -411,14 +438,14 @@ func testAllChannels(notify bool) error {
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
||||
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
}
|
||||
|
||||
// enable channel
|
||||
|
||||
@@ -52,6 +52,13 @@ func parseStatusFilter(statusParam string) int {
|
||||
}
|
||||
}
|
||||
|
||||
func clearChannelInfo(channel *model.Channel) {
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = nil
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
@@ -126,6 +133,10 @@ func GetAllChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, datum := range channelData {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
@@ -168,14 +179,26 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
var url string
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
|
||||
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||
var body []byte
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
|
||||
} else {
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
@@ -319,6 +342,10 @@ func SearchChannels(c *gin.Context) {
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
for _, datum := range pagedData {
|
||||
clearChannelInfo(datum)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -342,6 +369,9 @@ func GetChannel(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel != nil {
|
||||
clearChannelInfo(channel)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -350,6 +380,85 @@ func GetChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// GetChannelKey 验证2FA后获取渠道密钥
|
||||
func GetChannelKey(c *gin.Context) {
|
||||
type GetChannelKeyRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
var req GetChannelKeyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, fmt.Errorf("参数错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取2FA记录并验证
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("获取2FA信息失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
common.ApiError(c, fmt.Errorf("用户未启用2FA,无法查看密钥"))
|
||||
return
|
||||
}
|
||||
|
||||
// 统一的2FA验证逻辑
|
||||
if !validateTwoFactorAuth(twoFA, req.Code) {
|
||||
common.ApiError(c, fmt.Errorf("验证码或备用码错误,请重试"))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取渠道信息(包含密钥)
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if channel == nil {
|
||||
common.ApiError(c, fmt.Errorf("渠道不存在"))
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))
|
||||
|
||||
// 统一的成功响应格式
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "验证成功",
|
||||
"data": map[string]interface{}{
|
||||
"key": channel.Key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// validateTwoFactorAuth 统一的2FA验证函数
|
||||
func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
|
||||
// 尝试验证TOTP
|
||||
if cleanCode, err := common.ValidateNumericCode(code); err == nil {
|
||||
if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试验证备用码
|
||||
if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// validateChannel 通用的渠道校验函数
|
||||
func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
// 校验 channel settings
|
||||
@@ -669,6 +778,7 @@ func DeleteChannelBatch(c *gin.Context) {
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
@@ -688,7 +798,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
originChannel, err := model.GetChannelById(channel.Id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -704,6 +814,69 @@ func UpdateChannel(c *gin.Context) {
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
|
||||
// 处理多key模式下的密钥追加/覆盖逻辑
|
||||
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
|
||||
switch *channel.KeyMode {
|
||||
case "append":
|
||||
// 追加模式:将新密钥添加到现有密钥列表
|
||||
if originChannel.Key != "" {
|
||||
var newKeys []string
|
||||
var existingKeys []string
|
||||
|
||||
// 解析现有密钥
|
||||
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
|
||||
// JSON数组格式
|
||||
var arr []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
|
||||
existingKeys = make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
existingKeys[i] = string(v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 换行分隔格式
|
||||
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "追加密钥解析失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
newKeys = array
|
||||
} else {
|
||||
// 单个JSON密钥
|
||||
newKeys = []string{channel.Key}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
} else {
|
||||
// 普通渠道的处理
|
||||
inputKeys := strings.Split(channel.Key, "\n")
|
||||
for _, key := range inputKeys {
|
||||
key = strings.TrimSpace(key)
|
||||
if key != "" {
|
||||
newKeys = append(newKeys, key)
|
||||
}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
}
|
||||
}
|
||||
case "replace":
|
||||
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||
}
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -711,6 +884,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
}
|
||||
model.InitChannelCache()
|
||||
channel.Key = ""
|
||||
clearChannelInfo(&channel.Channel)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -914,3 +1088,413 @@ func CopyChannel(c *gin.Context) {
|
||||
// success
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||
}
|
||||
|
||||
// MultiKeyManageRequest represents the request for multi-key management operations
|
||||
type MultiKeyManageRequest struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
|
||||
Page int `json:"page,omitempty"` // for get_key_status pagination
|
||||
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
|
||||
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
|
||||
}
|
||||
|
||||
// MultiKeyStatusResponse represents the response for key status query
|
||||
type MultiKeyStatusResponse struct {
|
||||
Keys []KeyStatus `json:"keys"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
// Statistics
|
||||
EnabledCount int `json:"enabled_count"`
|
||||
ManualDisabledCount int `json:"manual_disabled_count"`
|
||||
AutoDisabledCount int `json:"auto_disabled_count"`
|
||||
}
|
||||
|
||||
type KeyStatus struct {
|
||||
Index int `json:"index"`
|
||||
Status int `json:"status"` // 1: enabled, 2: disabled
|
||||
DisabledTime int64 `json:"disabled_time,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
|
||||
}
|
||||
|
||||
// ManageMultiKeys handles multi-key management operations
|
||||
func ManageMultiKeys(c *gin.Context) {
|
||||
request := MultiKeyManageRequest{}
|
||||
err := c.ShouldBindJSON(&request)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(request.ChannelId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "渠道不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该渠道不是多密钥模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
lock := model.GetChannelPollingLock(channel.Id)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
switch request.Action {
|
||||
case "get_key_status":
|
||||
keys := channel.GetKeys()
|
||||
|
||||
// Default pagination parameters
|
||||
page := request.Page
|
||||
pageSize := request.PageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50 // Default page size
|
||||
}
|
||||
|
||||
// Statistics for all keys (unchanged by filtering)
|
||||
var enabledCount, manualDisabledCount, autoDisabledCount int
|
||||
|
||||
// Build all key status data first
|
||||
var allKeyStatusList []KeyStatus
|
||||
for i, key := range keys {
|
||||
status := 1 // default enabled
|
||||
var disabledTime int64
|
||||
var reason string
|
||||
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
}
|
||||
|
||||
// Count for statistics (all keys)
|
||||
switch status {
|
||||
case 1:
|
||||
enabledCount++
|
||||
case 2:
|
||||
manualDisabledCount++
|
||||
case 3:
|
||||
autoDisabledCount++
|
||||
}
|
||||
|
||||
if status != 1 {
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Create key preview (first 10 chars)
|
||||
keyPreview := key
|
||||
if len(key) > 10 {
|
||||
keyPreview = key[:10] + "..."
|
||||
}
|
||||
|
||||
allKeyStatusList = append(allKeyStatusList, KeyStatus{
|
||||
Index: i,
|
||||
Status: status,
|
||||
DisabledTime: disabledTime,
|
||||
Reason: reason,
|
||||
KeyPreview: keyPreview,
|
||||
})
|
||||
}
|
||||
|
||||
// Apply status filter if specified
|
||||
var filteredKeyStatusList []KeyStatus
|
||||
if request.Status != nil {
|
||||
for _, keyStatus := range allKeyStatusList {
|
||||
if keyStatus.Status == *request.Status {
|
||||
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
filteredKeyStatusList = allKeyStatusList
|
||||
}
|
||||
|
||||
// Calculate pagination based on filtered results
|
||||
filteredTotal := len(filteredKeyStatusList)
|
||||
totalPages := (filteredTotal + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
if page > totalPages {
|
||||
page = totalPages
|
||||
}
|
||||
|
||||
// Calculate range for current page
|
||||
start := (page - 1) * pageSize
|
||||
end := start + pageSize
|
||||
if end > filteredTotal {
|
||||
end = filteredTotal
|
||||
}
|
||||
|
||||
// Get the page data
|
||||
var pageKeyStatusList []KeyStatus
|
||||
if start < filteredTotal {
|
||||
pageKeyStatusList = filteredKeyStatusList[start:end]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": MultiKeyStatusResponse{
|
||||
Keys: pageKeyStatusList,
|
||||
Total: filteredTotal, // Total of filtered results
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
EnabledCount: enabledCount, // Overall statistics
|
||||
ManualDisabledCount: manualDisabledCount, // Overall statistics
|
||||
AutoDisabledCount: autoDisabledCount, // Overall statistics
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
case "disable_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要禁用的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
|
||||
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已禁用",
|
||||
})
|
||||
return
|
||||
|
||||
case "enable_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要启用的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
|
||||
}
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已启用",
|
||||
})
|
||||
return
|
||||
|
||||
case "enable_all_keys":
|
||||
// 清空所有禁用状态,使所有密钥回到默认启用状态
|
||||
var enabledCount int
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
|
||||
}
|
||||
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
|
||||
})
|
||||
return
|
||||
|
||||
case "disable_all_keys":
|
||||
// 禁用所有启用的密钥
|
||||
if channel.ChannelInfo.MultiKeyStatusList == nil {
|
||||
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
|
||||
}
|
||||
|
||||
var disabledCount int
|
||||
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
|
||||
status := 1 // default enabled
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
|
||||
// 只禁用当前启用的密钥
|
||||
if status == 1 {
|
||||
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
|
||||
disabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
if disabledCount == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "没有可禁用的密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_disabled_keys":
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
var deletedCount int
|
||||
var newStatusList = make(map[int]int)
|
||||
var newDisabledTime = make(map[int]int64)
|
||||
var newDisabledReason = make(map[int]string)
|
||||
|
||||
newIndex := 0
|
||||
for i, key := range keys {
|
||||
status := 1 // default enabled
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
|
||||
status = s
|
||||
}
|
||||
}
|
||||
|
||||
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
|
||||
if status == 3 {
|
||||
deletedCount++
|
||||
} else {
|
||||
remainingKeys = append(remainingKeys, key)
|
||||
// 保留非自动禁用密钥的状态信息,重新索引
|
||||
if status != 1 {
|
||||
newStatusList[newIndex] = status
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||||
newDisabledTime[newIndex] = t
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||||
newDisabledReason[newIndex] = r
|
||||
}
|
||||
}
|
||||
}
|
||||
newIndex++
|
||||
}
|
||||
}
|
||||
|
||||
if deletedCount == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "没有需要删除的自动禁用密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update channel with remaining keys
|
||||
channel.Key = strings.Join(remainingKeys, "\n")
|
||||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
|
||||
"data": deletedCount,
|
||||
})
|
||||
return
|
||||
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的操作",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,101 +3,102 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||
func MigrateConsoleSetting(c *gin.Context) {
|
||||
// 读取全部 option
|
||||
opts, err := model.AllOption()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// 建立 map
|
||||
valMap := map[string]string{}
|
||||
for _, o := range opts {
|
||||
valMap[o.Key] = o.Value
|
||||
}
|
||||
// 读取全部 option
|
||||
opts, err := model.AllOption()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// 建立 map
|
||||
valMap := map[string]string{}
|
||||
for _, o := range opts {
|
||||
valMap[o.Key] = o.Value
|
||||
}
|
||||
|
||||
// 处理 APIInfo
|
||||
if v := valMap["ApiInfo"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
if len(arr) > 50 {
|
||||
arr = arr[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(arr)
|
||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||
}
|
||||
model.UpdateOption("ApiInfo", "")
|
||||
}
|
||||
// Announcements 直接搬
|
||||
if v := valMap["Announcements"]; v != "" {
|
||||
model.UpdateOption("console_setting.announcements", v)
|
||||
model.UpdateOption("Announcements", "")
|
||||
}
|
||||
// FAQ 转换
|
||||
if v := valMap["FAQ"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
out := []map[string]interface{}{}
|
||||
for _, item := range arr {
|
||||
q, _ := item["question"].(string)
|
||||
if q == "" {
|
||||
q, _ = item["title"].(string)
|
||||
}
|
||||
a, _ := item["answer"].(string)
|
||||
if a == "" {
|
||||
a, _ = item["content"].(string)
|
||||
}
|
||||
if q != "" && a != "" {
|
||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||
}
|
||||
}
|
||||
if len(out) > 50 {
|
||||
out = out[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(out)
|
||||
model.UpdateOption("console_setting.faq", string(bytes))
|
||||
}
|
||||
model.UpdateOption("FAQ", "")
|
||||
}
|
||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||
url := valMap["UptimeKumaUrl"]
|
||||
slug := valMap["UptimeKumaSlug"]
|
||||
if url != "" && slug != "" {
|
||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||
groups := []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"categoryName": "old",
|
||||
"url": url,
|
||||
"slug": slug,
|
||||
"description": "",
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(groups)
|
||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||
}
|
||||
// 清空旧键内容
|
||||
if url != "" {
|
||||
model.UpdateOption("UptimeKumaUrl", "")
|
||||
}
|
||||
if slug != "" {
|
||||
model.UpdateOption("UptimeKumaSlug", "")
|
||||
}
|
||||
// 处理 APIInfo
|
||||
if v := valMap["ApiInfo"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
if len(arr) > 50 {
|
||||
arr = arr[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(arr)
|
||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||
}
|
||||
model.UpdateOption("ApiInfo", "")
|
||||
}
|
||||
// Announcements 直接搬
|
||||
if v := valMap["Announcements"]; v != "" {
|
||||
model.UpdateOption("console_setting.announcements", v)
|
||||
model.UpdateOption("Announcements", "")
|
||||
}
|
||||
// FAQ 转换
|
||||
if v := valMap["FAQ"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
out := []map[string]interface{}{}
|
||||
for _, item := range arr {
|
||||
q, _ := item["question"].(string)
|
||||
if q == "" {
|
||||
q, _ = item["title"].(string)
|
||||
}
|
||||
a, _ := item["answer"].(string)
|
||||
if a == "" {
|
||||
a, _ = item["content"].(string)
|
||||
}
|
||||
if q != "" && a != "" {
|
||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||
}
|
||||
}
|
||||
if len(out) > 50 {
|
||||
out = out[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(out)
|
||||
model.UpdateOption("console_setting.faq", string(bytes))
|
||||
}
|
||||
model.UpdateOption("FAQ", "")
|
||||
}
|
||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||
url := valMap["UptimeKumaUrl"]
|
||||
slug := valMap["UptimeKumaSlug"]
|
||||
if url != "" && slug != "" {
|
||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||
groups := []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"categoryName": "old",
|
||||
"url": url,
|
||||
"slug": slug,
|
||||
"description": "",
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(groups)
|
||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||
}
|
||||
// 清空旧键内容
|
||||
if url != "" {
|
||||
model.UpdateOption("UptimeKumaUrl", "")
|
||||
}
|
||||
if slug != "" {
|
||||
model.UpdateOption("UptimeKumaSlug", "")
|
||||
}
|
||||
|
||||
// 删除旧键记录
|
||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||
// 删除旧键记录
|
||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||
|
||||
// 重新加载 OptionMap
|
||||
model.InitOptionMap()
|
||||
common.SysLog("console setting migrated")
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||
}
|
||||
// 重新加载 OptionMap
|
||||
model.InitOptionMap()
|
||||
common.SysLog("console setting migrated")
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||
}
|
||||
|
||||
@@ -220,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
@@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
continue
|
||||
}
|
||||
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Midjourney)
|
||||
nullTaskIds := make([]int, 0)
|
||||
@@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||
} else {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
@@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
|
||||
}
|
||||
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
continue
|
||||
}
|
||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
})
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
// 设置超时时间
|
||||
@@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []dto.MidjourneyDto
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
@@ -145,9 +146,25 @@ func UpdateMidjourneyTaskBulk() {
|
||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
// 映射 VideoUrl
|
||||
task.VideoUrl = responseItem.VideoUrl
|
||||
|
||||
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
|
||||
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
|
||||
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
|
||||
task.VideoUrls = "[]" // 失败时设置为空数组
|
||||
} else {
|
||||
task.VideoUrls = string(videoUrlsStr)
|
||||
}
|
||||
} else {
|
||||
task.VideoUrls = "" // 空值时清空字段
|
||||
}
|
||||
|
||||
shouldReturnQuota := false
|
||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
if task.Quota != 0 {
|
||||
shouldReturnQuota = true
|
||||
@@ -155,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
|
||||
}
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
} else {
|
||||
if shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
@@ -208,6 +225,20 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||
return true
|
||||
}
|
||||
// 检查 VideoUrl 是否需要更新
|
||||
if oldTask.VideoUrl != newTask.VideoUrl {
|
||||
return true
|
||||
}
|
||||
// 检查 VideoUrls 是否需要更新
|
||||
if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
|
||||
newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
|
||||
if oldTask.VideoUrls != string(newVideoUrlsStr) {
|
||||
return true
|
||||
}
|
||||
} else if oldTask.VideoUrls != "" {
|
||||
// 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -39,48 +39,51 @@ func TestStatus(c *gin.Context) {
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
@@ -88,6 +91,10 @@ func GetStatus(c *gin.Context) {
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
// 模块管理配置
|
||||
"HeaderNavModules": common.OptionMap["HeaderNavModules"],
|
||||
"SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
|
||||
27
controller/missing_models.go
Normal file
27
controller/missing_models.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetMissingModels returns the list of model names that are referenced by channels
|
||||
// but do not have corresponding records in the models meta table.
|
||||
// This helps administrators quickly discover models that need configuration.
|
||||
func GetMissingModels(c *gin.Context) {
|
||||
missing, err := model.GetMissingModels()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": missing,
|
||||
})
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -92,7 +93,9 @@ func init() {
|
||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||
ChannelType: i,
|
||||
}}
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
@@ -102,7 +105,7 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
func ListModels(c *gin.Context, modelType int) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
@@ -171,11 +174,42 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
"object": "list",
|
||||
})
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||
for i, model := range userOpenAiModels {
|
||||
useranthropicModels[i] = dto.AnthropicModel{
|
||||
ID: model.Id,
|
||||
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
|
||||
DisplayName: model.Id,
|
||||
Type: "model",
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"data": useranthropicModels,
|
||||
"first_id": useranthropicModels[0].ID,
|
||||
"has_more": false,
|
||||
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
|
||||
})
|
||||
case constant.ChannelTypeGemini:
|
||||
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
|
||||
for i, model := range userOpenAiModels {
|
||||
userGeminiModels[i] = dto.GeminiModel{
|
||||
Name: model.Id,
|
||||
DisplayName: model.Id,
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"models": userGeminiModels,
|
||||
"nextPageToken": nil,
|
||||
})
|
||||
default:
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
"object": "list",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ChannelListModels(c *gin.Context) {
|
||||
@@ -199,10 +233,20 @@ func EnabledListModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func RetrieveModel(c *gin.Context) {
|
||||
func RetrieveModel(c *gin.Context, modelType int) {
|
||||
modelId := c.Param("model")
|
||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||
c.JSON(200, aiModel)
|
||||
switch modelType {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
c.JSON(200, dto.AnthropicModel{
|
||||
ID: aiModel.Id,
|
||||
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
|
||||
DisplayName: aiModel.Id,
|
||||
Type: "model",
|
||||
})
|
||||
default:
|
||||
c.JSON(200, aiModel)
|
||||
}
|
||||
} else {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
|
||||
330
controller/model_meta.go
Normal file
330
controller/model_meta.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetAllModelsMeta 获取模型列表(分页)
|
||||
func GetAllModelsMeta(c *gin.Context) {
|
||||
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// 批量填充附加字段,提升列表接口性能
|
||||
enrichModels(modelsMeta)
|
||||
var total int64
|
||||
model.DB.Model(&model.Model{}).Count(&total)
|
||||
|
||||
// 统计供应商计数(全部数据,不受分页影响)
|
||||
vendorCounts, _ := model.GetVendorModelCounts()
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(modelsMeta)
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"items": modelsMeta,
|
||||
"total": total,
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"vendor_counts": vendorCounts,
|
||||
})
|
||||
}
|
||||
|
||||
// SearchModelsMeta 搜索模型列表
|
||||
func SearchModelsMeta(c *gin.Context) {
|
||||
|
||||
keyword := c.Query("keyword")
|
||||
vendor := c.Query("vendor")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// 批量填充附加字段,提升列表接口性能
|
||||
enrichModels(modelsMeta)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(modelsMeta)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// GetModelMeta 根据 ID 获取单条模型信息
|
||||
func GetModelMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var m model.Model
|
||||
if err := model.DB.First(&m, id).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
enrichModels([]*model.Model{&m})
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
// CreateModelMeta 新建模型
|
||||
func CreateModelMeta(c *gin.Context) {
|
||||
var m model.Model
|
||||
if err := c.ShouldBindJSON(&m); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if m.ModelName == "" {
|
||||
common.ApiErrorMsg(c, "模型名称不能为空")
|
||||
return
|
||||
}
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "模型名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.Insert(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RefreshPricing()
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
// UpdateModelMeta 更新模型
|
||||
func UpdateModelMeta(c *gin.Context) {
|
||||
statusOnly := c.Query("status_only") == "true"
|
||||
|
||||
var m model.Model
|
||||
if err := c.ShouldBindJSON(&m); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if m.Id == 0 {
|
||||
common.ApiErrorMsg(c, "缺少模型 ID")
|
||||
return
|
||||
}
|
||||
|
||||
if statusOnly {
|
||||
// 只更新状态,防止误清空其他字段
|
||||
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "模型名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.Update(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
model.RefreshPricing()
|
||||
common.ApiSuccess(c, &m)
|
||||
}
|
||||
|
||||
// DeleteModelMeta 删除模型
|
||||
func DeleteModelMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.RefreshPricing()
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
|
||||
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
|
||||
func enrichModels(models []*model.Model) {
|
||||
if len(models) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 1) 拆分精确与规则匹配
|
||||
exactNames := make([]string, 0)
|
||||
exactIdx := make(map[string][]int) // modelName -> indices in models
|
||||
ruleIndices := make([]int, 0)
|
||||
for i, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.NameRule == model.NameRuleExact {
|
||||
exactNames = append(exactNames, m.ModelName)
|
||||
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
||||
} else {
|
||||
ruleIndices = append(ruleIndices, i)
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 批量查询精确模型的绑定渠道
|
||||
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
||||
|
||||
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
|
||||
for name, indices := range exactIdx {
|
||||
chs := channelsByModel[name]
|
||||
for _, idx := range indices {
|
||||
mm := models[idx]
|
||||
if mm.Endpoints == "" {
|
||||
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
||||
if b, err := json.Marshal(eps); err == nil {
|
||||
mm.Endpoints = string(b)
|
||||
}
|
||||
}
|
||||
mm.BoundChannels = chs
|
||||
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
||||
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ruleIndices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 4) 一次性读取定价缓存,内存匹配所有规则模型
|
||||
pricings := model.GetPricing()
|
||||
|
||||
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
|
||||
matchedNamesByIdx := make(map[int][]string)
|
||||
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
||||
groupSetByIdx := make(map[int]map[string]struct{})
|
||||
quotaSetByIdx := make(map[int]map[int]struct{})
|
||||
|
||||
for _, p := range pricings {
|
||||
for _, idx := range ruleIndices {
|
||||
mm := models[idx]
|
||||
var matched bool
|
||||
switch mm.NameRule {
|
||||
case model.NameRulePrefix:
|
||||
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
||||
case model.NameRuleSuffix:
|
||||
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
||||
case model.NameRuleContains:
|
||||
matched = strings.Contains(p.ModelName, mm.ModelName)
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
||||
|
||||
es := endpointSetByIdx[idx]
|
||||
if es == nil {
|
||||
es = make(map[constant.EndpointType]struct{})
|
||||
endpointSetByIdx[idx] = es
|
||||
}
|
||||
for _, et := range p.SupportedEndpointTypes {
|
||||
es[et] = struct{}{}
|
||||
}
|
||||
|
||||
gs := groupSetByIdx[idx]
|
||||
if gs == nil {
|
||||
gs = make(map[string]struct{})
|
||||
groupSetByIdx[idx] = gs
|
||||
}
|
||||
for _, g := range p.EnableGroup {
|
||||
gs[g] = struct{}{}
|
||||
}
|
||||
|
||||
qs := quotaSetByIdx[idx]
|
||||
if qs == nil {
|
||||
qs = make(map[int]struct{})
|
||||
quotaSetByIdx[idx] = qs
|
||||
}
|
||||
qs[p.QuotaType] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
|
||||
allMatchedSet := make(map[string]struct{})
|
||||
for _, names := range matchedNamesByIdx {
|
||||
for _, n := range names {
|
||||
allMatchedSet[n] = struct{}{}
|
||||
}
|
||||
}
|
||||
allMatched := make([]string, 0, len(allMatchedSet))
|
||||
for n := range allMatchedSet {
|
||||
allMatched = append(allMatched, n)
|
||||
}
|
||||
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
||||
|
||||
// 6) 回填每个规则模型的并集信息
|
||||
for _, idx := range ruleIndices {
|
||||
mm := models[idx]
|
||||
|
||||
// 端点并集 -> 序列化
|
||||
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
||||
eps := make([]constant.EndpointType, 0, len(es))
|
||||
for et := range es {
|
||||
eps = append(eps, et)
|
||||
}
|
||||
if b, err := json.Marshal(eps); err == nil {
|
||||
mm.Endpoints = string(b)
|
||||
}
|
||||
}
|
||||
|
||||
// 分组并集
|
||||
if gs, ok := groupSetByIdx[idx]; ok {
|
||||
groups := make([]string, 0, len(gs))
|
||||
for g := range gs {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
mm.EnableGroups = groups
|
||||
}
|
||||
|
||||
// 配额类型集合(保持去重并排序)
|
||||
if qs, ok := quotaSetByIdx[idx]; ok {
|
||||
arr := make([]int, 0, len(qs))
|
||||
for k := range qs {
|
||||
arr = append(arr, k)
|
||||
}
|
||||
sort.Ints(arr)
|
||||
mm.QuotaTypes = arr
|
||||
}
|
||||
|
||||
// 渠道并集
|
||||
names := matchedNamesByIdx[idx]
|
||||
channelSet := make(map[string]model.BoundChannel)
|
||||
for _, n := range names {
|
||||
for _, ch := range matchedChannelsByModel[n] {
|
||||
key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
||||
channelSet[key] = ch
|
||||
}
|
||||
}
|
||||
if len(channelSet) > 0 {
|
||||
chs := make([]model.BoundChannel, 0, len(channelSet))
|
||||
for _, ch := range channelSet {
|
||||
chs = append(chs, ch)
|
||||
}
|
||||
mm.BoundChannels = chs
|
||||
}
|
||||
|
||||
// 匹配信息
|
||||
mm.MatchedModels = names
|
||||
mm.MatchedCount = len(names)
|
||||
}
|
||||
}
|
||||
463
controller/model_sync.go
Normal file
463
controller/model_sync.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 上游地址
|
||||
const (
|
||||
upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||||
upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
|
||||
)
|
||||
|
||||
type upstreamEnvelope[T any] struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data []T `json:"data"`
|
||||
}
|
||||
|
||||
type upstreamModel struct {
|
||||
Description string `json:"description"`
|
||||
Endpoints json.RawMessage `json:"endpoints"`
|
||||
Icon string `json:"icon"`
|
||||
ModelName string `json:"model_name"`
|
||||
NameRule int `json:"name_rule"`
|
||||
Status int `json:"status"`
|
||||
Tags string `json:"tags"`
|
||||
VendorName string `json:"vendor_name"`
|
||||
}
|
||||
|
||||
type upstreamVendor struct {
|
||||
Description string `json:"description"`
|
||||
Icon string `json:"icon"`
|
||||
Name string `json:"name"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
type overwriteField struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Fields []string `json:"fields"`
|
||||
}
|
||||
|
||||
type syncRequest struct {
|
||||
Overwrite []overwriteField `json:"overwrite"`
|
||||
}
|
||||
|
||||
func newHTTPClient() *http.Client {
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
if strings.HasSuffix(host, "github.io") {
|
||||
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp6", addr)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
|
||||
var httpClient = newHTTPClient()
|
||||
|
||||
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = errors.New(resp.Status)
|
||||
return
|
||||
}
|
||||
limited := io.LimitReader(resp.Body, 10<<20)
|
||||
if err := json.NewDecoder(limited).Decode(out); err != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||
out.Success = true
|
||||
}
|
||||
lastErr = nil
|
||||
}()
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
|
||||
if vendorName == "" {
|
||||
return 0
|
||||
}
|
||||
if id, ok := vendorIDCache[vendorName]; ok {
|
||||
return id
|
||||
}
|
||||
var existing model.Vendor
|
||||
if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
|
||||
vendorIDCache[vendorName] = existing.Id
|
||||
return existing.Id
|
||||
}
|
||||
uv := vendorByName[vendorName]
|
||||
v := &model.Vendor{
|
||||
Name: vendorName,
|
||||
Description: uv.Description,
|
||||
Icon: coalesce(uv.Icon, ""),
|
||||
Status: chooseStatus(uv.Status, 1),
|
||||
}
|
||||
if err := v.Insert(); err == nil {
|
||||
*createdVendors++
|
||||
vendorIDCache[vendorName] = v.Id
|
||||
return v.Id
|
||||
}
|
||||
vendorIDCache[vendorName] = 0
|
||||
return 0
|
||||
}
|
||||
|
||||
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||
func SyncUpstreamModels(c *gin.Context) {
|
||||
var req syncRequest
|
||||
// 允许空体
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
// 1) 获取未配置模型列表
|
||||
missing, err := model.GetMissingModels()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(missing) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||
"created_models": 0,
|
||||
"created_vendors": 0,
|
||||
"skipped_models": []string{},
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 拉取上游 vendors 与 models
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||
_ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv) // 若失败不拦截,后续降级
|
||||
|
||||
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||
if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 建立映射
|
||||
vendorByName := make(map[string]upstreamVendor)
|
||||
for _, v := range vendorsEnv.Data {
|
||||
if v.Name != "" {
|
||||
vendorByName[v.Name] = v
|
||||
}
|
||||
}
|
||||
modelByName := make(map[string]upstreamModel)
|
||||
for _, m := range modelsEnv.Data {
|
||||
if m.ModelName != "" {
|
||||
modelByName[m.ModelName] = m
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
|
||||
createdModels := 0
|
||||
createdVendors := 0
|
||||
updatedModels := 0
|
||||
var skipped []string
|
||||
var createdList []string
|
||||
var updatedList []string
|
||||
|
||||
// 本地缓存:vendorName -> id
|
||||
vendorIDCache := make(map[string]int)
|
||||
|
||||
for _, name := range missing {
|
||||
up, ok := modelByName[name]
|
||||
if !ok {
|
||||
skipped = append(skipped, name)
|
||||
continue
|
||||
}
|
||||
|
||||
// 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
|
||||
var existing model.Model
|
||||
if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
|
||||
if existing.SyncOfficial == 0 {
|
||||
skipped = append(skipped, name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 vendor 存在
|
||||
vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||
|
||||
// 创建模型
|
||||
mi := &model.Model{
|
||||
ModelName: name,
|
||||
Description: up.Description,
|
||||
Icon: up.Icon,
|
||||
Tags: up.Tags,
|
||||
VendorID: vendorID,
|
||||
Status: chooseStatus(up.Status, 1),
|
||||
NameRule: up.NameRule,
|
||||
}
|
||||
if err := mi.Insert(); err == nil {
|
||||
createdModels++
|
||||
createdList = append(createdList, name)
|
||||
} else {
|
||||
skipped = append(skipped, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 处理可选覆盖(更新本地已有模型的差异字段)
|
||||
if len(req.Overwrite) > 0 {
|
||||
// vendorIDCache 已用于创建阶段,可复用
|
||||
for _, ow := range req.Overwrite {
|
||||
up, ok := modelByName[ow.ModelName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var local model.Model
|
||||
if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过被禁用官方同步的模型
|
||||
if local.SyncOfficial == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 映射 vendor
|
||||
newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||
|
||||
// 应用字段覆盖(事务)
|
||||
_ = model.DB.Transaction(func(tx *gorm.DB) error {
|
||||
needUpdate := false
|
||||
if containsField(ow.Fields, "description") {
|
||||
local.Description = up.Description
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "icon") {
|
||||
local.Icon = up.Icon
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "tags") {
|
||||
local.Tags = up.Tags
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "vendor") {
|
||||
local.VendorID = newVendorID
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "name_rule") {
|
||||
local.NameRule = up.NameRule
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "status") {
|
||||
local.Status = chooseStatus(up.Status, local.Status)
|
||||
needUpdate = true
|
||||
}
|
||||
if !needUpdate {
|
||||
return nil
|
||||
}
|
||||
if err := tx.Save(&local).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
updatedModels++
|
||||
updatedList = append(updatedList, ow.ModelName)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"created_models": createdModels,
|
||||
"created_vendors": createdVendors,
|
||||
"updated_models": updatedModels,
|
||||
"skipped_models": skipped,
|
||||
"created_list": createdList,
|
||||
"updated_list": updatedList,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func containsField(fields []string, key string) bool {
|
||||
key = strings.ToLower(strings.TrimSpace(key))
|
||||
for _, f := range fields {
|
||||
if strings.ToLower(strings.TrimSpace(f)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func coalesce(a, b string) string {
|
||||
if strings.TrimSpace(a) != "" {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func chooseStatus(primary, fallback int) int {
|
||||
if primary == 0 && fallback != 0 {
|
||||
return fallback
|
||||
}
|
||||
if primary != 0 {
|
||||
return primary
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
|
||||
func SyncUpstreamPreview(c *gin.Context) {
|
||||
// 1) 拉取上游数据
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||
_ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv)
|
||||
|
||||
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||
if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
vendorByName := make(map[string]upstreamVendor)
|
||||
for _, v := range vendorsEnv.Data {
|
||||
if v.Name != "" {
|
||||
vendorByName[v.Name] = v
|
||||
}
|
||||
}
|
||||
modelByName := make(map[string]upstreamModel)
|
||||
upstreamNames := make([]string, 0, len(modelsEnv.Data))
|
||||
for _, m := range modelsEnv.Data {
|
||||
if m.ModelName != "" {
|
||||
modelByName[m.ModelName] = m
|
||||
upstreamNames = append(upstreamNames, m.ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 本地已有模型
|
||||
var locals []model.Model
|
||||
if len(upstreamNames) > 0 {
|
||||
_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
|
||||
}
|
||||
|
||||
// 本地 vendor 名称映射
|
||||
vendorIdSet := make(map[int]struct{})
|
||||
for _, m := range locals {
|
||||
if m.VendorID != 0 {
|
||||
vendorIdSet[m.VendorID] = struct{}{}
|
||||
}
|
||||
}
|
||||
vendorIDs := make([]int, 0, len(vendorIdSet))
|
||||
for id := range vendorIdSet {
|
||||
vendorIDs = append(vendorIDs, id)
|
||||
}
|
||||
idToVendorName := make(map[int]string)
|
||||
if len(vendorIDs) > 0 {
|
||||
var dbVendors []model.Vendor
|
||||
_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
|
||||
for _, v := range dbVendors {
|
||||
idToVendorName[v.Id] = v.Name
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 缺失且上游存在的模型
|
||||
missingList, _ := model.GetMissingModels()
|
||||
var missing []string
|
||||
for _, name := range missingList {
|
||||
if _, ok := modelByName[name]; ok {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 计算冲突字段
|
||||
type conflictField struct {
|
||||
Field string `json:"field"`
|
||||
Local interface{} `json:"local"`
|
||||
Upstream interface{} `json:"upstream"`
|
||||
}
|
||||
type conflictItem struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Fields []conflictField `json:"fields"`
|
||||
}
|
||||
|
||||
var conflicts []conflictItem
|
||||
for _, local := range locals {
|
||||
up, ok := modelByName[local.ModelName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fields := make([]conflictField, 0, 6)
|
||||
if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
|
||||
fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
|
||||
}
|
||||
if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
|
||||
fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
|
||||
}
|
||||
if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
|
||||
fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
|
||||
}
|
||||
// vendor 对比使用名称
|
||||
localVendor := idToVendorName[local.VendorID]
|
||||
if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
|
||||
fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
|
||||
}
|
||||
if local.NameRule != up.NameRule {
|
||||
fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
|
||||
}
|
||||
if local.Status != chooseStatus(up.Status, local.Status) {
|
||||
fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
|
||||
}
|
||||
if len(fields) > 0 {
|
||||
conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"missing": missing,
|
||||
"conflicts": conflicts,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
||||
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
||||
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
||||
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
|
||||
@@ -5,10 +5,8 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
@@ -28,41 +26,19 @@ func Playground(c *gin.Context) {
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
group := playgroundRequest.Group
|
||||
userGroup := c.GetString("group")
|
||||
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
group := c.GetString("group")
|
||||
modelName := c.GetString("original_model")
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
@@ -73,12 +49,12 @@ func Playground(c *gin.Context) {
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
_, newAPIError = getChannel(c, group, modelName, 0)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
Relay(c)
|
||||
Relay(c, types.RelayFormatOpenAI)
|
||||
}
|
||||
|
||||
90
controller/prefill_group.go
Normal file
90
controller/prefill_group.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
||||
func GetPrefillGroups(c *gin.Context) {
|
||||
groupType := c.Query("type")
|
||||
groups, err := model.GetAllPrefillGroups(groupType)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, groups)
|
||||
}
|
||||
|
||||
// CreatePrefillGroup 创建新的预填组
|
||||
func CreatePrefillGroup(c *gin.Context) {
|
||||
var g model.PrefillGroup
|
||||
if err := c.ShouldBindJSON(&g); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if g.Name == "" || g.Type == "" {
|
||||
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
||||
return
|
||||
}
|
||||
// 创建前检查名称
|
||||
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "组名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := g.Insert(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &g)
|
||||
}
|
||||
|
||||
// UpdatePrefillGroup 更新预填组
|
||||
func UpdatePrefillGroup(c *gin.Context) {
|
||||
var g model.PrefillGroup
|
||||
if err := c.ShouldBindJSON(&g); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if g.Id == 0 {
|
||||
common.ApiErrorMsg(c, "缺少组 ID")
|
||||
return
|
||||
}
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "组名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := g.Update(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &g)
|
||||
}
|
||||
|
||||
// DeletePrefillGroup 删除预填组
|
||||
func DeletePrefillGroup(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DeletePrefillGroupByID(id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
@@ -39,10 +39,13 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"vendors": model.GetVendors(),
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||
"auto_groups": setting.AutoGroups,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,474 +1,539 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"one-api/logger"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||
floatEpsilon = 1e-9
|
||||
)
|
||||
|
||||
func nearlyEqual(a, b float64) bool {
|
||||
if a > b {
|
||||
return a-b < floatEpsilon
|
||||
}
|
||||
return b-a < floatEpsilon
|
||||
}
|
||||
|
||||
func valuesEqual(a, b interface{}) bool {
|
||||
af, aok := a.(float64)
|
||||
bf, bok := b.(float64)
|
||||
if aok && bok {
|
||||
return nearlyEqual(af, bf)
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
// 对 github.io 优先尝试 IPv4,失败则回退 IPv6
|
||||
if strings.HasSuffix(host, "github.io") {
|
||||
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp6", addr)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
endpoint := chItem.Endpoint
|
||||
var fullURL string
|
||||
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||
fullURL = endpoint
|
||||
} else {
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL = chItem.BaseURL + endpoint
|
||||
}
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
// 简单重试:最多 3 次,指数退避
|
||||
var resp *http.Response
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
resp, lastErr = client.Do(httpReq)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||
}
|
||||
if lastErr != nil {
|
||||
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
// Content-Type 和响应体大小校验
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
||||
logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
||||
}
|
||||
limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
// 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
converted := make(map[string]any)
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
return differences
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && !valuesEqual(localValue, val) {
|
||||
hasDifference = true
|
||||
} else if valuesEqual(localValue, val) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: -100,
|
||||
Name: "官方倍率预设",
|
||||
BaseURL: "https://basellm.github.io",
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -63,7 +64,7 @@ func AddRedemption(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
|
||||
@@ -2,115 +2,193 @@ package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
switch relayMode {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err = relay.ImageHelper(c)
|
||||
err = relay.ImageHelper(c, info)
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err = relay.AudioHelper(c)
|
||||
err = relay.AudioHelper(c, info)
|
||||
case relayconstant.RelayModeRerank:
|
||||
err = relay.RerankHelper(c, relayMode)
|
||||
err = relay.RerankHelper(c, info)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
err = relay.EmbeddingHelper(c)
|
||||
err = relay.EmbeddingHelper(c, info)
|
||||
case relayconstant.RelayModeResponses:
|
||||
err = relay.ResponsesHelper(c)
|
||||
case relayconstant.RelayModeGemini:
|
||||
err = relay.GeminiHelper(c)
|
||||
err = relay.ResponsesHelper(c, info)
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
err = relay.TextHelper(c, info)
|
||||
}
|
||||
|
||||
if constant2.ErrorLogEnabled && err != nil {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
modelName := c.GetString("original_model")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.ErrorType
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
if strings.Contains(c.Request.URL.Path, "embed") {
|
||||
err = relay.GeminiEmbeddingHandler(c, info)
|
||||
} else {
|
||||
err = relay.GeminiHelper(c, info)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var newAPIError *types.NewAPIError
|
||||
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||
|
||||
var (
|
||||
newAPIError *types.NewAPIError
|
||||
ws *websocket.Conn
|
||||
)
|
||||
|
||||
if relayFormat == types.RelayFormatOpenAIRealtime {
|
||||
var err error
|
||||
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
defer ws.Close()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
case types.RelayFormatClaude:
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
default:
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
||||
return
|
||||
}
|
||||
|
||||
meta := request.GetTokenCountMeta()
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
contains, words := service.CheckSensitiveText(meta.CombineText)
|
||||
if contains {
|
||||
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||
return
|
||||
}
|
||||
|
||||
relayInfo.SetPromptTokens(tokens)
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
return
|
||||
}
|
||||
|
||||
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||
|
||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||
if newAPIError != nil && preConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
logger.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
newAPIError = relayRequest(c, relayMode, channel)
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAIRealtime:
|
||||
newAPIError = relay.WssHelper(c, relayInfo)
|
||||
case types.RelayFormatClaude:
|
||||
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
||||
case types.RelayFormatGemini:
|
||||
newAPIError = geminiRelayHandler(c, relayInfo)
|
||||
default:
|
||||
newAPIError = relayHandler(c, relayInfo)
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
if newAPIError == nil {
|
||||
return
|
||||
}
|
||||
|
||||
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
logger.LogInfo(c, retryLogStr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,122 +199,6 @@ var upgrader = websocket.Upgrader{
|
||||
},
|
||||
}
|
||||
|
||||
func WssRelay(c *gin.Context) {
|
||||
// 将 HTTP 连接升级为 WebSocket 连接
|
||||
|
||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
}
|
||||
}
|
||||
|
||||
func RelayClaude(c *gin.Context) {
|
||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
newAPIError = claudeRequest(c, channel)
|
||||
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.ClaudeHelper(c)
|
||||
}
|
||||
|
||||
func addUsedChannel(c *gin.Context, channelId int) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
@@ -259,10 +221,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
}
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
if group == "auto" {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
@@ -278,7 +240,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
if types.IsSkipRetryError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
@@ -301,10 +263,6 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == 408 {
|
||||
@@ -318,44 +276,84 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
|
||||
gopool.Go(func() {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
modelName := c.GetString("original_model")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.GetErrorType()
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
adminInfo := make(map[string]interface{})
|
||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||
if isMultiKey {
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
var err *dto.MidjourneyResponse
|
||||
switch relayMode {
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
||||
"type": "upstream_error",
|
||||
"code": 4,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var mjErr *dto.MidjourneyResponse
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeMidjourneyNotify:
|
||||
err = relay.RelayMidjourneyNotify(c)
|
||||
mjErr = relay.RelayMidjourneyNotify(c)
|
||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
||||
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
||||
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
||||
case relayconstant.RelayModeSwapFace:
|
||||
err = relay.RelaySwapFace(c)
|
||||
mjErr = relay.RelaySwapFace(c, relayInfo)
|
||||
default:
|
||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
||||
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
||||
}
|
||||
//err = relayMidjourneySubmit(c, relayMode)
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
log.Println(mjErr)
|
||||
if mjErr != nil {
|
||||
statusCode := http.StatusBadRequest
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
if mjErr.Code == 30 {
|
||||
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
statusCode = http.StatusTooManyRequests
|
||||
}
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
||||
"type": "upstream_error",
|
||||
"code": err.Code,
|
||||
"code": mjErr.Code,
|
||||
})
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -386,18 +384,21 @@ func RelayNotFound(c *gin.Context) {
|
||||
func RelayTask(c *gin.Context) {
|
||||
retryTimes := common.RetryTimes
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||
taskErr := taskRelayHandler(c, relayMode)
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
taskErr := taskRelayHandler(c, relayInfo)
|
||||
if taskErr == nil {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||
if newAPIError != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
@@ -405,17 +406,17 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
taskErr = taskRelayHandler(c, relayMode)
|
||||
taskErr = taskRelayHandler(c, relayInfo)
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
logger.LogInfo(c, retryLogStr)
|
||||
}
|
||||
if taskErr != nil {
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
@@ -425,13 +426,13 @@ func RelayTask(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
err = relay.RelayTaskSubmit(c, relayInfo)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -114,3 +114,23 @@ type KlingImage2VideoRequest struct {
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||
}
|
||||
|
||||
// KlingImage2videoTaskId godoc
|
||||
// @Summary 可灵任务查询--图生视频
|
||||
// @Description Query the status and result of a Kling video generation task by task ID
|
||||
// @Tags Origin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Router /kling/v1/videos/image2video/{task_id} [get]
|
||||
func KlingImage2videoTaskId(c *gin.Context) {}
|
||||
|
||||
// KlingText2videoTaskId godoc
|
||||
// @Summary 可灵任务查询--文生视频
|
||||
// @Description Query the status and result of a Kling text-to-video generation task by task ID
|
||||
// @Tags Origin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Router /kling/v1/videos/text2video/{task_id} [get]
|
||||
func KlingText2videoTaskId(c *gin.Context) {}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"sort"
|
||||
@@ -54,9 +55,9 @@ func UpdateTaskBulk() {
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
} else {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
@@ -75,10 +76,10 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -106,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -118,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
"ids": taskIds,
|
||||
})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
return err
|
||||
}
|
||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
@@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
||||
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -2,27 +2,31 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -34,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
@@ -44,7 +48,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -58,7 +62,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
@@ -77,13 +81,21 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||
if err != nil {
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = responseBody
|
||||
}
|
||||
//if taskResult.Code != 0 {
|
||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
//}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
@@ -113,13 +125,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
@@ -128,10 +140,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -82,6 +83,57 @@ func GetTokenStatus(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func GetTokenUsage(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "No Authorization header",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid Bearer token",
|
||||
})
|
||||
return
|
||||
}
|
||||
tokenKey := parts[1]
|
||||
|
||||
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
expiredAt := token.ExpiredTime
|
||||
if expiredAt == -1 {
|
||||
expiredAt = 0
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": true,
|
||||
"message": "ok",
|
||||
"data": gin.H{
|
||||
"object": "token_usage",
|
||||
"name": token.Name,
|
||||
"total_granted": token.RemainQuota + token.UsedQuota,
|
||||
"total_used": token.UsedQuota,
|
||||
"total_available": token.RemainQuota,
|
||||
"unlimited_quota": token.UnlimitedQuota,
|
||||
"model_limits": token.GetModelLimitsMap(),
|
||||
"model_limits_enabled": token.ModelLimitsEnabled,
|
||||
"expires_at": expiredAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
@@ -102,7 +154,7 @@ func AddToken(c *gin.Context) {
|
||||
"success": false,
|
||||
"message": "生成令牌失败",
|
||||
})
|
||||
common.SysError("failed to generate token key: " + err.Error())
|
||||
common.SysLog("failed to generate token key: " + err.Error())
|
||||
return
|
||||
}
|
||||
cleanToken := model.Token{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
@@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
|
||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Setup2FARequest 设置2FA请求结构
|
||||
type Setup2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Verify2FARequest 验证2FA请求结构
|
||||
type Verify2FARequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// Setup2FAResponse 设置2FA响应结构
|
||||
type Setup2FAResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeData string `json:"qr_code_data"`
|
||||
BackupCodes []string `json:"backup_codes"`
|
||||
}
|
||||
|
||||
// Setup2FA 初始化2FA设置
|
||||
func Setup2FA(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 检查用户是否已经启用2FA
|
||||
existing, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if existing != nil && existing.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果存在已禁用的2FA记录,先删除它
|
||||
if existing != nil && !existing.IsEnabled {
|
||||
if err := existing.Delete(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
existing = nil // 重置为nil,后续将创建新记录
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成TOTP密钥
|
||||
key, err := common.GenerateTOTPSecret(user.Username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成2FA密钥失败",
|
||||
})
|
||||
common.SysLog("生成TOTP密钥失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysLog("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 生成二维码数据
|
||||
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||
|
||||
// 创建或更新2FA记录(暂未启用)
|
||||
twoFA := &model.TwoFA{
|
||||
UserId: userId,
|
||||
Secret: key.Secret(),
|
||||
IsEnabled: false,
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
// 更新现有记录
|
||||
twoFA.Id = existing.Id
|
||||
err = twoFA.Update()
|
||||
} else {
|
||||
// 创建新记录
|
||||
err = twoFA.Create()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建备用码记录
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysLog("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||
"data": Setup2FAResponse{
|
||||
Secret: key.Secret(),
|
||||
QRCodeData: qrCodeData,
|
||||
BackupCodes: backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Enable2FA 启用2FA
|
||||
func Enable2FA(c *gin.Context) {
|
||||
var req Setup2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请先完成2FA初始化设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
if twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "2FA已经启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 启用2FA
|
||||
if err := twoFA.Enable(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证启用成功",
|
||||
})
|
||||
}
|
||||
|
||||
// Disable2FA 禁用2FA
|
||||
func Disable2FA(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "两步验证已禁用",
|
||||
})
|
||||
}
|
||||
|
||||
// Get2FAStatus 获取用户2FA状态
|
||||
func Get2FAStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
status := map[string]interface{}{
|
||||
"enabled": false,
|
||||
"locked": false,
|
||||
}
|
||||
|
||||
if twoFA != nil {
|
||||
status["enabled"] = twoFA.IsEnabled
|
||||
status["locked"] = twoFA.IsLocked()
|
||||
if twoFA.IsEnabled {
|
||||
// 获取剩余备用码数量
|
||||
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||
if err != nil {
|
||||
common.SysLog("获取备用码数量失败: " + err.Error())
|
||||
} else {
|
||||
status["backup_codes_remaining"] = backupCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": status,
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateBackupCodes 重新生成备用码
|
||||
func RegenerateBackupCodes(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新的备用码
|
||||
backupCodes, err := common.GenerateBackupCodes()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成备用码失败",
|
||||
})
|
||||
common.SysLog("生成备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 保存新的备用码
|
||||
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "保存备用码失败",
|
||||
})
|
||||
common.SysLog("保存备用码失败: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "备用码重新生成成功",
|
||||
"data": map[string]interface{}{
|
||||
"backup_codes": backupCodes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify2FALogin 登录时验证2FA
|
||||
func Verify2FALogin(c *gin.Context) {
|
||||
var req Verify2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 从会话中获取pending用户信息
|
||||
session := sessions.Default(c)
|
||||
pendingUserId := session.Get("pending_user_id")
|
||||
if pendingUserId == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话已过期,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId, ok := pendingUserId.(int)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "会话数据无效,请重新登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取用户信息
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取2FA记录
|
||||
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if twoFA == nil || !twoFA.IsEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证TOTP验证码或备用码
|
||||
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||
isValidTOTP := false
|
||||
isValidBackup := false
|
||||
|
||||
if err == nil {
|
||||
// 尝试验证TOTP
|
||||
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||
}
|
||||
|
||||
if !isValidTOTP {
|
||||
// 尝试验证备用码
|
||||
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidTOTP && !isValidBackup {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码或备用码错误,请重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2FA验证成功,清理pending会话信息并完成登录
|
||||
session.Delete("pending_username")
|
||||
session.Delete("pending_user_id")
|
||||
session.Save()
|
||||
|
||||
setupLogin(user, c)
|
||||
}
|
||||
|
||||
// Admin2FAStats 管理员获取2FA统计信息
|
||||
func Admin2FAStats(c *gin.Context) {
|
||||
stats, err := model.GetTwoFAStats()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||
func AdminDisable2FA(c *gin.Context) {
|
||||
userIdStr := c.Param("id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户ID格式错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查目标用户权限
|
||||
targetUser, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权操作同级或更高级用户的2FA设置",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 禁用2FA
|
||||
if err := model.DisableTwoFA(userId); err != nil {
|
||||
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户未启用2FA",
|
||||
})
|
||||
return
|
||||
}
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录操作日志
|
||||
adminId := c.GetInt("id")
|
||||
model.RecordLog(userId, model.LogTypeManage,
|
||||
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "用户2FA已被强制禁用",
|
||||
})
|
||||
}
|
||||
@@ -31,7 +31,7 @@ type Monitor struct {
|
||||
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
@@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
@@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
if g.Wait() != nil {
|
||||
@@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
||||
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
@@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
@@ -62,6 +63,32 @@ func Login(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否启用2FA
|
||||
if model.IsTwoFAEnabled(user.Id) {
|
||||
// 设置pending session,等待2FA验证
|
||||
session := sessions.Default(c)
|
||||
session.Set("pending_username", user.Username)
|
||||
session.Set("pending_user_id", user.Id)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "请输入两步验证码",
|
||||
"success": true,
|
||||
"data": map[string]interface{}{
|
||||
"require_2fa": true,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
@@ -166,7 +193,7 @@ func Register(c *gin.Context) {
|
||||
"success": false,
|
||||
"message": "数据库错误,请稍后重试",
|
||||
})
|
||||
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||
common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||
return
|
||||
}
|
||||
if exist {
|
||||
@@ -183,6 +210,7 @@ func Register(c *gin.Context) {
|
||||
Password: user.Password,
|
||||
DisplayName: user.Username,
|
||||
InviterId: inviterId,
|
||||
Role: common.RoleCommonUser, // 明确设置角色为普通用户
|
||||
}
|
||||
if common.EmailVerificationEnabled {
|
||||
cleanUser.Email = user.Email
|
||||
@@ -209,7 +237,7 @@ func Register(c *gin.Context) {
|
||||
"success": false,
|
||||
"message": "生成默认令牌失败",
|
||||
})
|
||||
common.SysError("failed to generate token key: " + err.Error())
|
||||
common.SysLog("failed to generate token key: " + err.Error())
|
||||
return
|
||||
}
|
||||
// 生成默认令牌
|
||||
@@ -316,7 +344,7 @@ func GenerateAccessToken(c *gin.Context) {
|
||||
"success": false,
|
||||
"message": "生成失败",
|
||||
})
|
||||
common.SysError("failed to generate key: " + err.Error())
|
||||
common.SysLog("failed to generate key: " + err.Error())
|
||||
return
|
||||
}
|
||||
user.SetAccessToken(key)
|
||||
@@ -399,6 +427,7 @@ func GetAffCode(c *gin.Context) {
|
||||
|
||||
func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
userRole := c.GetInt("role")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -407,14 +436,134 @@ func GetSelf(c *gin.Context) {
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
// 计算用户权限信息
|
||||
permissions := calculateUserPermissions(userRole)
|
||||
|
||||
// 获取用户设置并提取sidebar_modules
|
||||
userSetting := user.GetSetting()
|
||||
|
||||
// 构建响应数据,包含用户信息和权限
|
||||
responseData := map[string]interface{}{
|
||||
"id": user.Id,
|
||||
"username": user.Username,
|
||||
"display_name": user.DisplayName,
|
||||
"role": user.Role,
|
||||
"status": user.Status,
|
||||
"email": user.Email,
|
||||
"group": user.Group,
|
||||
"quota": user.Quota,
|
||||
"used_quota": user.UsedQuota,
|
||||
"request_count": user.RequestCount,
|
||||
"aff_code": user.AffCode,
|
||||
"aff_count": user.AffCount,
|
||||
"aff_quota": user.AffQuota,
|
||||
"aff_history_quota": user.AffHistoryQuota,
|
||||
"inviter_id": user.InviterId,
|
||||
"linux_do_id": user.LinuxDOId,
|
||||
"setting": user.Setting,
|
||||
"stripe_customer": user.StripeCustomer,
|
||||
"sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
|
||||
"permissions": permissions, // 新增权限字段
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
"data": responseData,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算用户权限的辅助函数
|
||||
func calculateUserPermissions(userRole int) map[string]interface{} {
|
||||
permissions := map[string]interface{}{}
|
||||
|
||||
// 根据用户角色计算权限
|
||||
if userRole == common.RoleRootUser {
|
||||
// 超级管理员不需要边栏设置功能
|
||||
permissions["sidebar_settings"] = false
|
||||
permissions["sidebar_modules"] = map[string]interface{}{}
|
||||
} else if userRole == common.RoleAdminUser {
|
||||
// 管理员可以设置边栏,但不包含系统设置功能
|
||||
permissions["sidebar_settings"] = true
|
||||
permissions["sidebar_modules"] = map[string]interface{}{
|
||||
"admin": map[string]interface{}{
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// 普通用户只能设置个人功能,不包含管理员区域
|
||||
permissions["sidebar_settings"] = true
|
||||
permissions["sidebar_modules"] = map[string]interface{}{
|
||||
"admin": false, // 普通用户不能访问管理员区域
|
||||
}
|
||||
}
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// 根据用户角色生成默认的边栏配置
|
||||
func generateDefaultSidebarConfig(userRole int) string {
|
||||
defaultConfig := map[string]interface{}{}
|
||||
|
||||
// 聊天区域 - 所有用户都可以访问
|
||||
defaultConfig["chat"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"playground": true,
|
||||
"chat": true,
|
||||
}
|
||||
|
||||
// 控制台区域 - 所有用户都可以访问
|
||||
defaultConfig["console"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"detail": true,
|
||||
"token": true,
|
||||
"log": true,
|
||||
"midjourney": true,
|
||||
"task": true,
|
||||
}
|
||||
|
||||
// 个人中心区域 - 所有用户都可以访问
|
||||
defaultConfig["personal"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"topup": true,
|
||||
"personal": true,
|
||||
}
|
||||
|
||||
// 管理员区域 - 根据角色决定
|
||||
if userRole == common.RoleAdminUser {
|
||||
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
}
|
||||
} else if userRole == common.RoleRootUser {
|
||||
// 超级管理员可以访问所有功能
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": true,
|
||||
}
|
||||
}
|
||||
// 普通用户不包含admin区域
|
||||
|
||||
// 转换为JSON字符串
|
||||
configBytes, err := json.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(configBytes)
|
||||
}
|
||||
|
||||
func GetUserModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -491,7 +640,7 @@ func UpdateUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if originUser.Quota != updatedUser.Quota {
|
||||
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
|
||||
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -501,8 +650,8 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
func UpdateSelf(c *gin.Context) {
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
var requestData map[string]interface{}
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -510,6 +659,60 @@ func UpdateSelf(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否是sidebar_modules更新请求
|
||||
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户设置
|
||||
currentSetting := user.GetSetting()
|
||||
|
||||
// 更新sidebar_modules字段
|
||||
if sidebarModulesStr, ok := sidebarModules.(string); ok {
|
||||
currentSetting.SidebarModules = sidebarModulesStr
|
||||
}
|
||||
|
||||
// 保存更新后的设置
|
||||
user.SetSetting(currentSetting)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "更新设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "设置更新成功",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 原有的用户信息更新逻辑
|
||||
var user model.User
|
||||
requestDataBytes, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(requestDataBytes, &user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
@@ -652,6 +855,7 @@ func CreateUser(c *gin.Context) {
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role, // 保持管理员设置的角色
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -817,18 +1021,64 @@ type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
var topUpLock = sync.Mutex{}
|
||||
var topUpLocks sync.Map
|
||||
var topUpCreateLock sync.Mutex
|
||||
|
||||
type topUpTryLock struct {
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
func newTopUpTryLock() *topUpTryLock {
|
||||
return &topUpTryLock{ch: make(chan struct{}, 1)}
|
||||
}
|
||||
|
||||
func (l *topUpTryLock) TryLock() bool {
|
||||
select {
|
||||
case l.ch <- struct{}{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (l *topUpTryLock) Unlock() {
|
||||
select {
|
||||
case <-l.ch:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func getTopUpLock(userID int) *topUpTryLock {
|
||||
if v, ok := topUpLocks.Load(userID); ok {
|
||||
return v.(*topUpTryLock)
|
||||
}
|
||||
topUpCreateLock.Lock()
|
||||
defer topUpCreateLock.Unlock()
|
||||
if v, ok := topUpLocks.Load(userID); ok {
|
||||
return v.(*topUpTryLock)
|
||||
}
|
||||
l := newTopUpTryLock()
|
||||
topUpLocks.Store(userID, l)
|
||||
return l
|
||||
}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
topUpLock.Lock()
|
||||
defer topUpLock.Unlock()
|
||||
id := c.GetInt("id")
|
||||
lock := getTopUpLock(id)
|
||||
if !lock.TryLock() {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "充值处理中,请稍后重试",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer lock.Unlock()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(req.Key, id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -839,7 +1089,6 @@ func TopUp(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": quota,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateUserSettingRequest struct {
|
||||
@@ -848,6 +1097,7 @@ type UpdateUserSettingRequest struct {
|
||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
BarkUrl string `json:"bark_url,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
@@ -863,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
@@ -911,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是Bark类型,验证Bark URL
|
||||
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||
if req.BarkUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Bark推送URL不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 验证URL格式
|
||||
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的Bark推送URL",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 检查是否是HTTP或HTTPS
|
||||
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Bark推送URL必须以http://或https://开头",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
@@ -939,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 如果是Bark类型,添加Bark URL到设置中
|
||||
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||
settings.BarkUrl = req.BarkUrl
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
user.SetSetting(settings)
|
||||
if err := user.Update(false); err != nil {
|
||||
|
||||
124
controller/vendor_meta.go
Normal file
124
controller/vendor_meta.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetAllVendors 获取供应商列表(分页)
|
||||
func GetAllVendors(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
var total int64
|
||||
model.DB.Model(&model.Vendor{}).Count(&total)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(vendors)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// SearchVendors 搜索供应商
|
||||
func SearchVendors(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(vendors)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
// GetVendorMeta 根据 ID 获取供应商
|
||||
func GetVendorMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
v, err := model.GetVendorByID(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, v)
|
||||
}
|
||||
|
||||
// CreateVendorMeta 新建供应商
|
||||
func CreateVendorMeta(c *gin.Context) {
|
||||
var v model.Vendor
|
||||
if err := c.ShouldBindJSON(&v); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if v.Name == "" {
|
||||
common.ApiErrorMsg(c, "供应商名称不能为空")
|
||||
return
|
||||
}
|
||||
// 创建前先检查名称
|
||||
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := v.Insert(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &v)
|
||||
}
|
||||
|
||||
// UpdateVendorMeta 更新供应商
|
||||
func UpdateVendorMeta(c *gin.Context) {
|
||||
var v model.Vendor
|
||||
if err := c.ShouldBindJSON(&v); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if v.Id == 0 {
|
||||
common.ApiErrorMsg(c, "缺少供应商 ID")
|
||||
return
|
||||
}
|
||||
// 名称冲突检查
|
||||
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
} else if dup {
|
||||
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := v.Update(); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, &v)
|
||||
}
|
||||
|
||||
// DeleteVendorMeta 删除供应商
|
||||
func DeleteVendorMeta(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.ApiSuccess(c, nil)
|
||||
}
|
||||
Reference in New Issue
Block a user