mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-18 00:57:26 +00:00
Merge branch 'upstream-main' into feature/improve-param-override
# Conflicts: # relay/channel/api_request_test.go # relay/common/override_test.go # web/src/components/table/channels/modals/EditChannelModal.jsx
This commit is contained in:
@@ -30,8 +30,8 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
<a href="https://trendshift.io/repositories/20180" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
<a href="https://trendshift.io/repositories/20180" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
<a href="https://trendshift.io/repositories/20180" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
<a href="https://trendshift.io/repositories/20180" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
<a href="https://trendshift.io/repositories/20180" target="_blank">
|
||||
<img src="https://trendshift.io/api/badge/repositories/20180" alt="QuantumNous%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://hellogithub.com/repository/QuantumNous/new-api" target="_blank">
|
||||
|
||||
@@ -243,7 +243,15 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
// Use the original Content-Type saved on first call to avoid boundary
|
||||
// mismatch when callers overwrite c.Request.Header after multipart rebuild.
|
||||
var contentType string
|
||||
if saved, ok := c.Get("_original_multipart_ct"); ok {
|
||||
contentType = saved.(string)
|
||||
} else {
|
||||
contentType = c.Request.Header.Get("Content-Type")
|
||||
c.Set("_original_multipart_ct", contentType)
|
||||
}
|
||||
boundary, err := parseBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -295,7 +303,13 @@ func parseFormData(data []byte, v any) error {
|
||||
}
|
||||
|
||||
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
var contentType string
|
||||
if saved, ok := c.Get("_original_multipart_ct"); ok {
|
||||
contentType = saved.(string)
|
||||
} else {
|
||||
contentType = c.Request.Header.Get("Content-Type")
|
||||
c.Set("_original_multipart_ct", contentType)
|
||||
}
|
||||
boundary, err := parseBoundary(contentType)
|
||||
if err != nil {
|
||||
if errors.Is(err, errBoundaryNotFound) {
|
||||
|
||||
@@ -145,6 +145,8 @@ func initConstantEnv() {
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
// 任务轮询时查询的最大数量
|
||||
constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)
|
||||
// 异步任务超时时间(分钟),超过此时间未完成的任务将被标记为失败并退款。0 表示禁用。
|
||||
constant.TaskTimeoutMinutes = GetEnvOrDefault("TASK_TIMEOUT_MINUTES", 1440)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
|
||||
@@ -16,6 +16,7 @@ var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
var TaskQueryLimit int
|
||||
var TaskTimeoutMinutes int
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
return
|
||||
}
|
||||
|
||||
channelProxy := ""
|
||||
if channelID > 0 {
|
||||
ch, err := model.GetChannelById(channelID, false)
|
||||
if err != nil {
|
||||
@@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
||||
return
|
||||
}
|
||||
channelProxy = ch.GetSetting().Proxy
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
@@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
|
||||
tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy)
|
||||
if err != nil {
|
||||
common.SysError("failed to exchange codex authorization code: " + err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
|
||||
|
||||
@@ -2,7 +2,6 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) {
|
||||
refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
|
||||
defer refreshCancel()
|
||||
|
||||
res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
|
||||
res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
|
||||
if refreshErr == nil {
|
||||
oauthKey.AccessToken = res.AccessToken
|
||||
oauthKey.RefreshToken = res.RefreshToken
|
||||
@@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
var payload any
|
||||
if json.Unmarshal(body, &payload) != nil {
|
||||
if common.Unmarshal(body, &payload) != nil {
|
||||
payload = string(body)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
@@ -16,6 +21,7 @@ type CustomOAuthProviderResponse struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Icon string `json:"icon"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
@@ -28,6 +34,16 @@ type CustomOAuthProviderResponse struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
AccessPolicy string `json:"access_policy"`
|
||||
AccessDeniedMessage string `json:"access_denied_message"`
|
||||
}
|
||||
|
||||
type UserOAuthBindingResponse struct {
|
||||
ProviderId int `json:"provider_id"`
|
||||
ProviderName string `json:"provider_name"`
|
||||
ProviderSlug string `json:"provider_slug"`
|
||||
ProviderIcon string `json:"provider_icon"`
|
||||
ProviderUserId string `json:"provider_user_id"`
|
||||
}
|
||||
|
||||
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
|
||||
@@ -35,6 +51,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
|
||||
Id: p.Id,
|
||||
Name: p.Name,
|
||||
Slug: p.Slug,
|
||||
Icon: p.Icon,
|
||||
Enabled: p.Enabled,
|
||||
ClientId: p.ClientId,
|
||||
AuthorizationEndpoint: p.AuthorizationEndpoint,
|
||||
@@ -47,6 +64,8 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro
|
||||
EmailField: p.EmailField,
|
||||
WellKnown: p.WellKnown,
|
||||
AuthStyle: p.AuthStyle,
|
||||
AccessPolicy: p.AccessPolicy,
|
||||
AccessDeniedMessage: p.AccessDeniedMessage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +115,7 @@ func GetCustomOAuthProvider(c *gin.Context) {
|
||||
type CreateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Slug string `json:"slug" binding:"required"`
|
||||
Icon string `json:"icon"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id" binding:"required"`
|
||||
ClientSecret string `json:"client_secret" binding:"required"`
|
||||
@@ -109,6 +129,85 @@ type CreateCustomOAuthProviderRequest struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
AccessPolicy string `json:"access_policy"`
|
||||
AccessDeniedMessage string `json:"access_denied_message"`
|
||||
}
|
||||
|
||||
type FetchCustomOAuthDiscoveryRequest struct {
|
||||
WellKnownURL string `json:"well_known_url"`
|
||||
IssuerURL string `json:"issuer_url"`
|
||||
}
|
||||
|
||||
// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route)
|
||||
func FetchCustomOAuthDiscovery(c *gin.Context) {
|
||||
var req FetchCustomOAuthDiscoveryRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
wellKnownURL := strings.TrimSpace(req.WellKnownURL)
|
||||
issuerURL := strings.TrimSpace(req.IssuerURL)
|
||||
|
||||
if wellKnownURL == "" && issuerURL == "" {
|
||||
common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL")
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := wellKnownURL
|
||||
if targetURL == "" {
|
||||
targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
|
||||
}
|
||||
targetURL = strings.TrimSpace(targetURL)
|
||||
|
||||
parsedURL, err := url.Parse(targetURL)
|
||||
if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
|
||||
common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
message := strings.TrimSpace(string(body))
|
||||
if message == "" {
|
||||
message = resp.Status
|
||||
}
|
||||
common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message)
|
||||
return
|
||||
}
|
||||
|
||||
var discovery map[string]any
|
||||
if err = common.DecodeJson(resp.Body, &discovery); err != nil {
|
||||
common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"well_known_url": targetURL,
|
||||
"discovery": discovery,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// CreateCustomOAuthProvider creates a new custom OAuth provider
|
||||
@@ -134,6 +233,7 @@ func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
provider := &model.CustomOAuthProvider{
|
||||
Name: req.Name,
|
||||
Slug: req.Slug,
|
||||
Icon: req.Icon,
|
||||
Enabled: req.Enabled,
|
||||
ClientId: req.ClientId,
|
||||
ClientSecret: req.ClientSecret,
|
||||
@@ -147,6 +247,8 @@ func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
EmailField: req.EmailField,
|
||||
WellKnown: req.WellKnown,
|
||||
AuthStyle: req.AuthStyle,
|
||||
AccessPolicy: req.AccessPolicy,
|
||||
AccessDeniedMessage: req.AccessDeniedMessage,
|
||||
}
|
||||
|
||||
if err := model.CreateCustomOAuthProvider(provider); err != nil {
|
||||
@@ -168,9 +270,10 @@ func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
type UpdateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
|
||||
Icon *string `json:"icon"` // Optional: if nil, keep existing
|
||||
Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
|
||||
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
@@ -181,6 +284,8 @@ type UpdateCustomOAuthProviderRequest struct {
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
|
||||
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
|
||||
AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing
|
||||
AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
||||
@@ -227,6 +332,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
if req.Slug != "" {
|
||||
provider.Slug = req.Slug
|
||||
}
|
||||
if req.Icon != nil {
|
||||
provider.Icon = *req.Icon
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
provider.Enabled = *req.Enabled
|
||||
}
|
||||
@@ -266,6 +374,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
if req.AuthStyle != nil {
|
||||
provider.AuthStyle = *req.AuthStyle
|
||||
}
|
||||
if req.AccessPolicy != nil {
|
||||
provider.AccessPolicy = *req.AccessPolicy
|
||||
}
|
||||
if req.AccessDeniedMessage != nil {
|
||||
provider.AccessDeniedMessage = *req.AccessDeniedMessage
|
||||
}
|
||||
|
||||
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -327,6 +441,30 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) {
|
||||
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := make([]UserOAuthBindingResponse, 0, len(bindings))
|
||||
for _, binding := range bindings {
|
||||
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
response = append(response, UserOAuthBindingResponse{
|
||||
ProviderId: binding.ProviderId,
|
||||
ProviderName: provider.Name,
|
||||
ProviderSlug: provider.Slug,
|
||||
ProviderIcon: provider.Icon,
|
||||
ProviderUserId: binding.ProviderUserId,
|
||||
})
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetUserOAuthBindings returns all OAuth bindings for the current user
|
||||
func GetUserOAuthBindings(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
@@ -335,32 +473,43 @@ func GetUserOAuthBindings(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
|
||||
response, err := buildUserOAuthBindingsResponse(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build response with provider info
|
||||
type BindingResponse struct {
|
||||
ProviderId int `json:"provider_id"`
|
||||
ProviderName string `json:"provider_name"`
|
||||
ProviderSlug string `json:"provider_slug"`
|
||||
ProviderUserId string `json:"provider_user_id"`
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserOAuthBindingsByAdmin(c *gin.Context) {
|
||||
userIdStr := c.Param("id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "invalid user id")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]BindingResponse, 0)
|
||||
for _, binding := range bindings {
|
||||
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
|
||||
if err != nil {
|
||||
continue // Skip if provider not found
|
||||
}
|
||||
response = append(response, BindingResponse{
|
||||
ProviderId: binding.ProviderId,
|
||||
ProviderName: provider.Name,
|
||||
ProviderSlug: provider.Slug,
|
||||
ProviderUserId: binding.ProviderUserId,
|
||||
})
|
||||
targetUser, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||
common.ApiErrorMsg(c, "no permission")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := buildUserOAuthBindingsResponse(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -395,3 +544,41 @@ func UnbindCustomOAuth(c *gin.Context) {
|
||||
"message": "解绑成功",
|
||||
})
|
||||
}
|
||||
|
||||
func UnbindCustomOAuthByAdmin(c *gin.Context) {
|
||||
userIdStr := c.Param("id")
|
||||
userId, err := strconv.Atoi(userIdStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "invalid user id")
|
||||
return
|
||||
}
|
||||
|
||||
targetUser, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||
common.ApiErrorMsg(c, "no permission")
|
||||
return
|
||||
}
|
||||
|
||||
providerIdStr := c.Param("provider_id")
|
||||
providerId, err := strconv.Atoi(providerIdStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "invalid provider id")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -105,13 +105,13 @@ func UpdateMidjourneyTaskBulk() {
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []dto.MidjourneyDto
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Mjp Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
@@ -130,6 +130,7 @@ func UpdateMidjourneyTaskBulk() {
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
preStatus := task.Status
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
@@ -172,18 +173,26 @@ func UpdateMidjourneyTaskBulk() {
|
||||
shouldReturnQuota = true
|
||||
}
|
||||
}
|
||||
err = task.Update()
|
||||
won, err := task.UpdateWithStatus(preStatus)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
} else {
|
||||
if shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
} else if won && shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
|
||||
UserId: task.UserId,
|
||||
LogType: model.LogTypeRefund,
|
||||
Content: "",
|
||||
ChannelId: task.ChannelId,
|
||||
ModelName: service.CovertMjpActionToModelName(task.Action),
|
||||
Quota: task.Quota,
|
||||
Other: map[string]interface{}{
|
||||
"task_id": task.MjId,
|
||||
"reason": "构图失败",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,8 +134,10 @@ func GetStatus(c *gin.Context) {
|
||||
customProviders := oauth.GetEnabledCustomProviders()
|
||||
if len(customProviders) > 0 {
|
||||
type CustomOAuthInfo struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Icon string `json:"icon"`
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
@@ -144,8 +146,10 @@ func GetStatus(c *gin.Context) {
|
||||
for _, p := range customProviders {
|
||||
config := p.GetConfig()
|
||||
providersInfo = append(providersInfo, CustomOAuthInfo{
|
||||
Id: config.Id,
|
||||
Name: config.Name,
|
||||
Slug: config.Slug,
|
||||
Icon: config.Icon,
|
||||
ClientId: config.ClientId,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint,
|
||||
Scopes: config.Scopes,
|
||||
|
||||
@@ -237,6 +237,16 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
|
||||
// Set up new user
|
||||
user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
|
||||
if oauthUser.Username != "" {
|
||||
if exists, err := model.CheckUserExistOrDeleted(oauthUser.Username, ""); err == nil && !exists {
|
||||
// 防止索引退化
|
||||
if len(oauthUser.Username) <= model.UserNameMaxLength {
|
||||
user.Username = oauthUser.Username
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if oauthUser.DisplayName != "" {
|
||||
user.DisplayName = oauthUser.DisplayName
|
||||
} else if oauthUser.Username != "" {
|
||||
@@ -295,12 +305,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
// Set the provider user ID on the user model and update
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
if err := tx.Model(user).Updates(map[string]interface{}{
|
||||
"github_id": user.GitHubId,
|
||||
"discord_id": user.DiscordId,
|
||||
"oidc_id": user.OidcId,
|
||||
"linux_do_id": user.LinuxDOId,
|
||||
"wechat_id": user.WeChatId,
|
||||
"telegram_id": user.TelegramId,
|
||||
"github_id": user.GitHubId,
|
||||
"discord_id": user.DiscordId,
|
||||
"oidc_id": user.OidcId,
|
||||
"linux_do_id": user.LinuxDOId,
|
||||
"wechat_id": user.WeChatId,
|
||||
"telegram_id": user.TelegramId,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -340,6 +350,8 @@ func handleOAuthError(c *gin.Context, err error) {
|
||||
} else {
|
||||
common.ApiErrorI18n(c, e.MsgKey)
|
||||
}
|
||||
case *oauth.AccessDeniedError:
|
||||
common.ApiErrorMsg(c, e.Message)
|
||||
case *oauth.TrustLevelError:
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
|
||||
default:
|
||||
|
||||
@@ -455,72 +455,147 @@ func RelayNotFound(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func RelayTask(c *gin.Context) {
|
||||
retryTimes := common.RetryTimes
|
||||
channelId := c.GetInt("channel_id")
|
||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||
func RelayTaskFetch(c *gin.Context) {
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, &dto.TaskError{
|
||||
Code: "gen_relay_info_failed",
|
||||
Message: err.Error(),
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
})
|
||||
return
|
||||
}
|
||||
taskErr := taskRelayHandler(c, relayInfo)
|
||||
if taskErr == nil {
|
||||
retryTimes = 0
|
||||
if taskErr := relay.RelayTaskFetch(c, relayInfo.RelayMode); taskErr != nil {
|
||||
respondTaskError(c, taskErr)
|
||||
}
|
||||
}
|
||||
|
||||
func RelayTask(c *gin.Context) {
|
||||
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, &dto.TaskError{
|
||||
Code: "gen_relay_info_failed",
|
||||
Message: err.Error(),
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if taskErr := relay.ResolveOriginTask(c, relayInfo); taskErr != nil {
|
||||
respondTaskError(c, taskErr)
|
||||
return
|
||||
}
|
||||
|
||||
var result *relay.TaskSubmitResult
|
||||
var taskErr *dto.TaskError
|
||||
defer func() {
|
||||
if taskErr != nil && relayInfo.Billing != nil {
|
||||
relayInfo.Billing.Refund(c)
|
||||
}
|
||||
}()
|
||||
|
||||
retryParam := &service.RetryParam{
|
||||
Ctx: c,
|
||||
TokenGroup: relayInfo.TokenGroup,
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
Retry: common.GetPointer(0),
|
||||
}
|
||||
for ; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && retryParam.GetRetry() < retryTimes; retryParam.IncreaseRetry() {
|
||||
channel, newAPIError := getChannel(c, relayInfo, retryParam)
|
||||
if newAPIError != nil {
|
||||
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
bodyStorage, err := common.GetBodyStorage(c)
|
||||
if err != nil {
|
||||
if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
||||
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
||||
var channel *model.Channel
|
||||
|
||||
if lockedCh, ok := relayInfo.LockedChannel.(*model.Channel); ok && lockedCh != nil {
|
||||
channel = lockedCh
|
||||
if retryParam.GetRetry() > 0 {
|
||||
if setupErr := middleware.SetupContextForSelectedChannel(c, channel, relayInfo.OriginModelName); setupErr != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(setupErr.Err, "setup_locked_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var channelErr *types.NewAPIError
|
||||
channel, channelErr = getChannel(c, relayInfo, retryParam)
|
||||
if channelErr != nil {
|
||||
logger.LogError(c, channelErr.Error())
|
||||
taskErr = service.TaskErrorWrapperLocal(channelErr.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
addUsedChannel(c, channel.Id)
|
||||
bodyStorage, bodyErr := common.GetBodyStorage(c)
|
||||
if bodyErr != nil {
|
||||
if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
|
||||
taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
||||
} else {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
|
||||
taskErr = service.TaskErrorWrapperLocal(bodyErr, "read_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
break
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bodyStorage)
|
||||
taskErr = taskRelayHandler(c, relayInfo)
|
||||
|
||||
result, taskErr = relay.RelayTaskSubmit(c, relayInfo)
|
||||
if taskErr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if !taskErr.LocalError {
|
||||
processChannelError(c,
|
||||
*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey,
|
||||
common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()),
|
||||
types.NewOpenAIError(taskErr.Error, types.ErrorCodeBadResponseStatusCode, taskErr.StatusCode))
|
||||
}
|
||||
|
||||
if !shouldRetryTaskRelay(c, channel.Id, taskErr, common.RetryTimes-retryParam.GetRetry()) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
logger.LogInfo(c, retryLogStr)
|
||||
}
|
||||
if taskErr != nil {
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
|
||||
// ── 成功:结算 + 日志 + 插入任务 ──
|
||||
if taskErr == nil {
|
||||
if settleErr := service.SettleBilling(c, relayInfo, result.Quota); settleErr != nil {
|
||||
common.SysError("settle task billing error: " + settleErr.Error())
|
||||
}
|
||||
c.JSON(taskErr.StatusCode, taskErr)
|
||||
service.LogTaskConsumption(c, relayInfo)
|
||||
|
||||
task := model.InitTask(result.Platform, relayInfo)
|
||||
task.PrivateData.UpstreamTaskID = result.UpstreamTaskID
|
||||
task.PrivateData.BillingSource = relayInfo.BillingSource
|
||||
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
|
||||
task.PrivateData.TokenId = relayInfo.TokenId
|
||||
task.PrivateData.BillingContext = &model.TaskBillingContext{
|
||||
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||
OtherRatios: relayInfo.PriceData.OtherRatios,
|
||||
OriginModelName: relayInfo.OriginModelName,
|
||||
PerCallBilling: common.StringsContains(constant.TaskPricePatches, relayInfo.OriginModelName),
|
||||
}
|
||||
task.Quota = result.Quota
|
||||
task.Data = result.TaskData
|
||||
task.Action = relayInfo.Action
|
||||
if insertErr := task.Insert(); insertErr != nil {
|
||||
common.SysError("insert task error: " + insertErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if taskErr != nil {
|
||||
respondTaskError(c, taskErr)
|
||||
}
|
||||
}
|
||||
|
||||
func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayInfo.RelayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayInfo.RelayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayInfo)
|
||||
// respondTaskError 统一输出 Task 错误响应(含 429 限流提示改写)
|
||||
func respondTaskError(c *gin.Context, taskErr *dto.TaskError) {
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
return err
|
||||
c.JSON(taskErr.StatusCode, taskErr)
|
||||
}
|
||||
|
||||
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
||||
@@ -544,7 +619,7 @@ func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError,
|
||||
}
|
||||
if taskErr.StatusCode/100 == 5 {
|
||||
// 超时不重试
|
||||
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
|
||||
if operation_setting.IsAlwaysSkipRetryStatusCode(taskErr.StatusCode) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -1,231 +1,22 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// UpdateTaskBulk 薄入口,实际轮询逻辑在 service 层
|
||||
func UpdateTaskBulk() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
}
|
||||
for platform, tasks := range platformTask {
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Task)
|
||||
nullTaskIds := make([]int64, 0)
|
||||
for _, task := range tasks {
|
||||
if task.TaskID == "" {
|
||||
// 统计失败的未完成任务
|
||||
nullTaskIds = append(nullTaskIds, task.ID)
|
||||
continue
|
||||
}
|
||||
taskM[task.TaskID] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
|
||||
}
|
||||
if len(nullTaskIds) > 0 {
|
||||
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
} else {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
||||
}
|
||||
common.SysLog("任务进度轮询完成")
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
|
||||
switch platform {
|
||||
case constant.TaskPlatformMidjourney:
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
|
||||
if adaptor == nil {
|
||||
return errors.New("adaptor not found")
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
||||
"ids": taskIds,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
return err
|
||||
}
|
||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
for _, responseItem := range responseItems.Data {
|
||||
task := taskM[responseItem.TaskID]
|
||||
if !checkTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
|
||||
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
|
||||
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
|
||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
if responseItem.Status == model.TaskStatusSuccess {
|
||||
task.Progress = "100%"
|
||||
}
|
||||
task.Data = responseItem.Data
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
|
||||
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if string(oldTask.Status) != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
|
||||
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
|
||||
return true
|
||||
}
|
||||
|
||||
oldData, _ := json.Marshal(oldTask.Data)
|
||||
newData, _ := json.Marshal(newTask.Data)
|
||||
|
||||
sort.Slice(oldData, func(i, j int) bool {
|
||||
return oldData[i] < oldData[j]
|
||||
})
|
||||
sort.Slice(newData, func(i, j int) bool {
|
||||
return newData[i] < newData[j]
|
||||
})
|
||||
|
||||
if string(oldData) != string(newData) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
service.TaskPollingLoop()
|
||||
}
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
@@ -247,7 +38,7 @@ func GetAllTask(c *gin.Context) {
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
pageInfo.SetItems(tasksToDto(items, true))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
@@ -271,6 +62,33 @@ func GetUserTask(c *gin.Context) {
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
pageInfo.SetItems(tasksToDto(items, false))
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func tasksToDto(tasks []*model.Task, fillUser bool) []*dto.TaskDto {
|
||||
var userIdMap map[int]*model.UserBase
|
||||
if fillUser {
|
||||
userIdMap = make(map[int]*model.UserBase)
|
||||
userIds := types.NewSet[int]()
|
||||
for _, task := range tasks {
|
||||
userIds.Add(task.UserId)
|
||||
}
|
||||
for _, userId := range userIds.Items() {
|
||||
cacheUser, err := model.GetUserCache(userId)
|
||||
if err == nil {
|
||||
userIdMap[userId] = cacheUser
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]*dto.TaskDto, len(tasks))
|
||||
for i, task := range tasks {
|
||||
if fillUser {
|
||||
if user, ok := userIdMap[task.UserId]; ok {
|
||||
task.Username = user.Username
|
||||
}
|
||||
}
|
||||
result[i] = relay.TaskModel2Dto(task)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,313 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
info := &relaycommon.RelayInfo{}
|
||||
info.ChannelMeta = &relaycommon.ChannelMeta{
|
||||
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
|
||||
}
|
||||
info.ApiKey = cacheGetChannel.Key
|
||||
adaptor.Init(info)
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
proxy := channel.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
key := channel.Key
|
||||
|
||||
privateData := task.PrivateData
|
||||
if privateData.Key != "" {
|
||||
key = privateData.Key
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.FailReason
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
task.Data = t.Data
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("UpdateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
|
||||
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
||||
if taskResult.TotalTokens > 0 {
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
||||
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if hasRatioSetting && modelRatio > 0 {
|
||||
// 获取用户和组的倍率信息
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group != "" {
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
finalGroupRatio = userGroupRatio
|
||||
} else {
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
||||
|
||||
// 计算差额
|
||||
preConsumedQuota := task.Quota
|
||||
quotaDelta := actualQuota - preConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
// 需要补扣费
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(quotaDelta),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录消费日志
|
||||
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,补扣费 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(quotaDelta))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else if quotaDelta < 0 {
|
||||
// 需要退还多扣的费用
|
||||
refundQuota := -quotaDelta
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(refundQuota),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
taskResult.TotalTokens,
|
||||
))
|
||||
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
||||
} else {
|
||||
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||
|
||||
// 记录退款日志
|
||||
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s,退还 %s",
|
||||
modelRatio, finalGroupRatio, taskResult.TotalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), logger.LogQuota(refundQuota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
} else {
|
||||
// quotaDelta == 0, 预扣费刚好准确
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
||||
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
@@ -582,6 +582,44 @@ func UpdateUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func AdminClearUserBinding(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
bindingType := strings.ToLower(strings.TrimSpace(c.Param("binding_type")))
|
||||
if bindingType == "" {
|
||||
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
||||
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel)
|
||||
return
|
||||
}
|
||||
|
||||
if err := user.ClearBinding(bindingType); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.RecordLog(user.Id, model.LogTypeManage, fmt.Sprintf("admin cleared %s binding for user %s", bindingType, user.Username))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
func UpdateSelf(c *gin.Context) {
|
||||
var requestData map[string]interface{}
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||
|
||||
@@ -16,59 +16,44 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// videoProxyError returns a standardized OpenAI-style error response.
|
||||
func videoProxyError(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
"type": errType,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func VideoProxy(c *gin.Context) {
|
||||
taskID := c.Param("task_id")
|
||||
if taskID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "task_id is required",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadRequest, "invalid_request_error", "task_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
task, exists, err := model.GetByOnlyTaskId(taskID)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to query task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to query task",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to query task")
|
||||
return
|
||||
}
|
||||
if !exists || task == nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: %v", taskID, err))
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Task not found",
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusNotFound, "invalid_request_error", "Task not found")
|
||||
return
|
||||
}
|
||||
|
||||
if task.Status != model.TaskStatusSuccess {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Task is not completed yet, current status: %s", task.Status),
|
||||
"type": "invalid_request_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadRequest, "invalid_request_error",
|
||||
fmt.Sprintf("Task is not completed yet, current status: %s", task.Status))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.CacheGetChannel(task.ChannelId)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get task %s: not found", taskID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to retrieve channel information",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to get channel for task %s: %s", taskID, err.Error()))
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to retrieve channel information")
|
||||
return
|
||||
}
|
||||
baseURL := channel.GetBaseURL()
|
||||
@@ -81,12 +66,7 @@ func VideoProxy(c *gin.Context) {
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create proxy client for task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy client",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy client")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -95,12 +75,7 @@ func VideoProxy(c *gin.Context) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "", nil)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to create request: %s", err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy request",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,68 +84,43 @@ func VideoProxy(c *gin.Context) {
|
||||
apiKey := task.PrivateData.Key
|
||||
if apiKey == "" {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "API key not stored for task",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "API key not stored for task")
|
||||
return
|
||||
}
|
||||
|
||||
videoURL, err = getGeminiVideoURL(channel, task, apiKey)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to resolve Gemini video URL for task %s: %s", taskID, err.Error()))
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to resolve Gemini video URL",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to resolve Gemini video URL")
|
||||
return
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
case constant.ChannelTypeOpenAI, constant.ChannelTypeSora:
|
||||
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
|
||||
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.GetUpstreamTaskID())
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
default:
|
||||
// Video URL is directly in task.FailReason
|
||||
videoURL = task.FailReason
|
||||
// Video URL is stored in PrivateData.ResultURL (fallback to FailReason for old data)
|
||||
videoURL = task.GetResultURL()
|
||||
}
|
||||
|
||||
req.URL, err = url.Parse(videoURL)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to parse URL %s: %s", videoURL, err.Error()))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to create proxy request",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusInternalServerError, "server_error", "Failed to create proxy request")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to fetch video from %s: %s", videoURL, err.Error()))
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Failed to fetch video content",
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadGateway, "server_error", "Failed to fetch video content")
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Upstream returned status %d for %s", resp.StatusCode, videoURL))
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Upstream service returned status %d", resp.StatusCode),
|
||||
"type": "server_error",
|
||||
},
|
||||
})
|
||||
videoProxyError(c, http.StatusBadGateway, "server_error",
|
||||
fmt.Sprintf("Upstream service returned status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -180,10 +130,9 @@ func VideoProxy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Cache-Control", "public, max-age=86400") // Cache for 24 hours
|
||||
c.Writer.Header().Set("Cache-Control", "public, max-age=86400")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
if _, err = io.Copy(c.Writer, resp.Body); err != nil {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Failed to stream video content: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
@@ -37,7 +37,7 @@ func getGeminiVideoURL(channel *model.Channel, task *model.Task, apiKey string)
|
||||
|
||||
proxy := channel.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(baseURL, apiKey, map[string]any{
|
||||
"task_id": task.TaskID,
|
||||
"task_id": task.GetUpstreamTaskID(),
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
@@ -71,7 +71,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
|
||||
return ""
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(task.Data, &payload); err != nil {
|
||||
if err := common.Unmarshal(task.Data, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
return extractGeminiVideoURLFromMap(payload)
|
||||
@@ -79,7 +79,7 @@ func extractGeminiVideoURLFromTaskData(task *model.Task) string {
|
||||
|
||||
func extractGeminiVideoURLFromPayload(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
return extractGeminiVideoURLFromMap(payload)
|
||||
|
||||
@@ -24,14 +24,16 @@ const (
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
|
||||
ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
|
||||
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
|
||||
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
|
||||
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
|
||||
ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true
|
||||
AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费)
|
||||
AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规)
|
||||
DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私)
|
||||
AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护)
|
||||
AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
|
||||
|
||||
@@ -190,10 +190,13 @@ type ClaudeToolChoice struct {
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System any `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
System any `json:"system,omitempty"`
|
||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||
// InferenceGeo controls Claude data residency region.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_inference_geo.
|
||||
InferenceGeo string `json:"inference_geo,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
@@ -210,7 +213,8 @@ type ClaudeRequest struct {
|
||||
Thinking *Thinking `json:"thinking,omitempty"`
|
||||
McpServers json.RawMessage `json:"mcp_servers,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
// 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
|
||||
// ServiceTier specifies upstream service level and may affect billing.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_service_tier.
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -324,25 +324,26 @@ type GeminiChatTool struct {
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
|
||||
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
|
||||
Logprobs *int32 `json:"logprobs,omitempty"`
|
||||
MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
|
||||
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
|
||||
Logprobs *int32 `json:"logprobs,omitempty"`
|
||||
EnableEnhancedCivicAnswers *bool `json:"enableEnhancedCivicAnswers,omitempty"`
|
||||
MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields.
|
||||
@@ -350,22 +351,23 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiChatGenerationConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
TopPSnake float64 `json:"top_p,omitempty"`
|
||||
TopKSnake float64 `json:"top_k,omitempty"`
|
||||
MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
|
||||
CandidateCountSnake int `json:"candidate_count,omitempty"`
|
||||
StopSequencesSnake []string `json:"stop_sequences,omitempty"`
|
||||
ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
|
||||
ResponseSchemaSnake any `json:"response_schema,omitempty"`
|
||||
ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
|
||||
PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
|
||||
ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
|
||||
MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
|
||||
ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
|
||||
ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"`
|
||||
SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"`
|
||||
ImageConfigSnake json.RawMessage `json:"image_config,omitempty"`
|
||||
TopPSnake float64 `json:"top_p,omitempty"`
|
||||
TopKSnake float64 `json:"top_k,omitempty"`
|
||||
MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
|
||||
CandidateCountSnake int `json:"candidate_count,omitempty"`
|
||||
StopSequencesSnake []string `json:"stop_sequences,omitempty"`
|
||||
ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
|
||||
ResponseSchemaSnake any `json:"response_schema,omitempty"`
|
||||
ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
|
||||
PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
|
||||
ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
|
||||
EnableEnhancedCivicAnswersSnake *bool `json:"enable_enhanced_civic_answers,omitempty"`
|
||||
MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
|
||||
ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
|
||||
ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"`
|
||||
SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"`
|
||||
ImageConfigSnake json.RawMessage `json:"image_config,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
@@ -408,6 +410,9 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
|
||||
if aux.ResponseLogprobsSnake {
|
||||
c.ResponseLogprobs = aux.ResponseLogprobsSnake
|
||||
}
|
||||
if aux.EnableEnhancedCivicAnswersSnake != nil {
|
||||
c.EnableEnhancedCivicAnswers = aux.EnableEnhancedCivicAnswersSnake
|
||||
}
|
||||
if aux.MediaResolutionSnake != "" {
|
||||
c.MediaResolution = aux.MediaResolutionSnake
|
||||
}
|
||||
@@ -453,12 +458,14 @@ type GeminiChatResponse struct {
|
||||
}
|
||||
|
||||
type GeminiUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
ToolUsePromptTokenCount int `json:"toolUsePromptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||
CachedContentTokenCount int `json:"cachedContentTokenCount"`
|
||||
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||
ToolUsePromptTokensDetails []GeminiPromptTokensDetails `json:"toolUsePromptTokensDetails"`
|
||||
}
|
||||
|
||||
type GeminiPromptTokensDetails struct {
|
||||
|
||||
@@ -54,18 +54,22 @@ type GeneralOpenAIRequest struct {
|
||||
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
FunctionCall json.RawMessage `json:"function_call,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
// ServiceTier specifies upstream service level and may affect billing.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_service_tier.
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
// 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户
|
||||
// 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤以保护用户隐私
|
||||
// 注意:此字段会向 OpenAI 发送用户标识信息,默认过滤,可通过 allow_safety_identifier 开启
|
||||
SafetyIdentifier string `json:"safety_identifier,omitempty"`
|
||||
// Whether or not to store the output of this chat completion request for use in our model distillation or evals products.
|
||||
// 是否存储此次请求数据供 OpenAI 用于评估和优化产品
|
||||
// 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用
|
||||
// 注意:默认允许透传,可通过 disable_store 禁用;禁用后可能导致 Codex 无法正常使用
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
// Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field
|
||||
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
|
||||
@@ -261,6 +265,9 @@ type FunctionRequest struct {
|
||||
|
||||
type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
// IncludeObfuscation is only for /v1/responses stream payload.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_include_obfuscation.
|
||||
IncludeObfuscation bool `json:"include_obfuscation,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
|
||||
@@ -799,30 +806,42 @@ type WebSearchOptions struct {
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/responses/create
|
||||
type OpenAIResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input,omitempty"`
|
||||
Include json.RawMessage `json:"include,omitempty"`
|
||||
// 在后台运行推理,暂时还不支持依赖的接口
|
||||
// Background json.RawMessage `json:"background,omitempty"`
|
||||
Conversation json.RawMessage `json:"conversation,omitempty"`
|
||||
ContextManagement json.RawMessage `json:"context_management,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
TopLogProbs *int `json:"top_logprobs,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
// 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
// ServiceTier specifies upstream service level and may affect billing.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_service_tier.
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
// Store controls whether upstream may store request/response data.
|
||||
// This field is allowed by default and can be disabled via channel setting disable_store.
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
// SafetyIdentifier carries client identity for policy abuse detection.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_safety_identifier.
|
||||
SafetyIdentifier string `json:"safety_identifier,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
// qwen
|
||||
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
|
||||
// perplexity
|
||||
|
||||
32
dto/suno.go
32
dto/suno.go
@@ -4,10 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type TaskData interface {
|
||||
SunoDataResponse | []SunoDataResponse | string | any
|
||||
}
|
||||
|
||||
type SunoSubmitReq struct {
|
||||
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
@@ -20,10 +16,6 @@ type SunoSubmitReq struct {
|
||||
MakeInstrumental bool `json:"make_instrumental"`
|
||||
}
|
||||
|
||||
type FetchReq struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
|
||||
type SunoDataResponse struct {
|
||||
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
|
||||
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
|
||||
@@ -66,30 +58,6 @@ type SunoLyrics struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
const TaskSuccessCode = "success"
|
||||
|
||||
type TaskResponse[T TaskData] struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data T `json:"data"`
|
||||
}
|
||||
|
||||
func (t *TaskResponse[T]) IsSuccess() bool {
|
||||
return t.Code == TaskSuccessCode
|
||||
}
|
||||
|
||||
type TaskDto struct {
|
||||
TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
|
||||
Action string `json:"action"` // 任务类型, song, lyrics, description-mode
|
||||
Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
|
||||
FailReason string `json:"fail_reason"`
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
Progress string `json:"progress"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type SunoGoAPISubmitReq struct {
|
||||
CustomMode bool `json:"custom_mode"`
|
||||
|
||||
|
||||
47
dto/task.go
47
dto/task.go
@@ -1,5 +1,9 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type TaskError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
@@ -8,3 +12,46 @@ type TaskError struct {
|
||||
LocalError bool `json:"-"`
|
||||
Error error `json:"-"`
|
||||
}
|
||||
|
||||
type TaskData interface {
|
||||
SunoDataResponse | []SunoDataResponse | string | any
|
||||
}
|
||||
|
||||
const TaskSuccessCode = "success"
|
||||
|
||||
type TaskResponse[T TaskData] struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data T `json:"data"`
|
||||
}
|
||||
|
||||
func (t *TaskResponse[T]) IsSuccess() bool {
|
||||
return t.Code == TaskSuccessCode
|
||||
}
|
||||
|
||||
type TaskDto struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
TaskID string `json:"task_id"`
|
||||
Platform string `json:"platform"`
|
||||
UserId int `json:"user_id"`
|
||||
Group string `json:"group"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
Quota int `json:"quota"`
|
||||
Action string `json:"action"`
|
||||
Status string `json:"status"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ResultURL string `json:"result_url,omitempty"` // 任务结果 URL(视频地址等)
|
||||
SubmitTime int64 `json:"submit_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
FinishTime int64 `json:"finish_time"`
|
||||
Progress string `json:"progress"`
|
||||
Properties any `json:"properties"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type FetchReq struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -151,7 +150,7 @@ func FormatQuota(quota int) string {
|
||||
|
||||
// LogJson 仅供测试使用 only for test
|
||||
func LogJson(ctx context.Context, msg string, obj any) {
|
||||
jsonStr, err := json.Marshal(obj)
|
||||
jsonStr, err := common.Marshal(obj)
|
||||
if err != nil {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
|
||||
10
main.go
10
main.go
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/QuantumNous/new-api/relay"
|
||||
"github.com/QuantumNous/new-api/router"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
_ "github.com/QuantumNous/new-api/setting/performance_setting"
|
||||
@@ -111,6 +112,15 @@ func main() {
|
||||
// Subscription quota reset task (daily/weekly/monthly/custom)
|
||||
service.StartSubscriptionQuotaResetTask()
|
||||
|
||||
// Wire task polling adaptor factory (breaks service -> relay import cycle)
|
||||
service.GetTaskAdaptorFunc = func(platform constant.TaskPlatform) service.TaskPollingAdaptor {
|
||||
a := relay.GetTaskAdaptor(platform)
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
if common.IsMasterNode && constant.UpdateTask {
|
||||
gopool.Go(func() {
|
||||
controller.UpdateMidjourneyTaskBulk()
|
||||
|
||||
@@ -170,6 +170,24 @@ func WssAuth(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
// TokenOrUserAuth allows either session-based user auth or API token auth.
|
||||
// Used for endpoints that need to be accessible from both the dashboard and API clients.
|
||||
func TokenOrUserAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
// Try session auth first (dashboard users)
|
||||
session := sessions.Default(c)
|
||||
if id := session.Get("id"); id != nil {
|
||||
if status, ok := session.Get("status").(int); ok && status == common.UserStatusEnabled {
|
||||
c.Set("id", id)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
// Fall back to token auth (API clients)
|
||||
TokenAuth()(c)
|
||||
}
|
||||
}
|
||||
|
||||
// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
|
||||
// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
|
||||
// 即使令牌已过期、已耗尽或已禁用,也允许访问。
|
||||
|
||||
@@ -7,14 +7,28 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const RouteTagKey = "route_tag"
|
||||
|
||||
func RouteTag(tag string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set(RouteTagKey, tag)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetUpLogger(server *gin.Engine) {
|
||||
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var requestID string
|
||||
if param.Keys != nil {
|
||||
requestID = param.Keys[common.RequestIdKey].(string)
|
||||
requestID, _ = param.Keys[common.RequestIdKey].(string)
|
||||
}
|
||||
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||
tag, _ := param.Keys[RouteTagKey].(string)
|
||||
if tag == "" {
|
||||
tag = "web"
|
||||
}
|
||||
return fmt.Sprintf("[GIN] %s | %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
tag,
|
||||
requestID,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
|
||||
@@ -2,32 +2,65 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
)
|
||||
|
||||
type accessPolicyPayload struct {
|
||||
Logic string `json:"logic"`
|
||||
Conditions []accessConditionItem `json:"conditions"`
|
||||
Groups []accessPolicyPayload `json:"groups"`
|
||||
}
|
||||
|
||||
type accessConditionItem struct {
|
||||
Field string `json:"field"`
|
||||
Op string `json:"op"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
var supportedAccessPolicyOps = map[string]struct{}{
|
||||
"eq": {},
|
||||
"ne": {},
|
||||
"gt": {},
|
||||
"gte": {},
|
||||
"lt": {},
|
||||
"lte": {},
|
||||
"in": {},
|
||||
"not_in": {},
|
||||
"contains": {},
|
||||
"not_contains": {},
|
||||
"exists": {},
|
||||
"not_exists": {},
|
||||
}
|
||||
|
||||
// CustomOAuthProvider stores configuration for custom OAuth providers
|
||||
type CustomOAuthProvider struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
|
||||
Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
|
||||
Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
|
||||
ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
|
||||
ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
|
||||
TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
|
||||
UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
|
||||
Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
|
||||
Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
|
||||
Icon string `json:"icon" gorm:"type:varchar(128);default:''"` // Icon name from @lobehub/icons
|
||||
Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
|
||||
ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
|
||||
ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
|
||||
TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
|
||||
UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
|
||||
Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
|
||||
|
||||
// Field mapping configuration (supports JSONPath via gjson)
|
||||
UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
|
||||
UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
|
||||
DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
|
||||
EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
|
||||
UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
|
||||
UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
|
||||
DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
|
||||
EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
|
||||
|
||||
// Advanced options
|
||||
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
|
||||
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
|
||||
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
|
||||
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
|
||||
AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info
|
||||
AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
@@ -158,6 +191,57 @@ func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
||||
if provider.Scopes == "" {
|
||||
provider.Scopes = "openid profile email"
|
||||
}
|
||||
if strings.TrimSpace(provider.AccessPolicy) != "" {
|
||||
var policy accessPolicyPayload
|
||||
if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil {
|
||||
return errors.New("access_policy must be valid JSON")
|
||||
}
|
||||
if err := validateAccessPolicyPayload(&policy); err != nil {
|
||||
return fmt.Errorf("access_policy is invalid: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateAccessPolicyPayload(policy *accessPolicyPayload) error {
|
||||
if policy == nil {
|
||||
return errors.New("policy is nil")
|
||||
}
|
||||
|
||||
logic := strings.ToLower(strings.TrimSpace(policy.Logic))
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
if logic != "and" && logic != "or" {
|
||||
return fmt.Errorf("unsupported logic: %s", logic)
|
||||
}
|
||||
|
||||
if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
|
||||
return errors.New("policy requires at least one condition or group")
|
||||
}
|
||||
|
||||
for index, condition := range policy.Conditions {
|
||||
field := strings.TrimSpace(condition.Field)
|
||||
if field == "" {
|
||||
return fmt.Errorf("condition[%d].field is required", index)
|
||||
}
|
||||
op := strings.ToLower(strings.TrimSpace(condition.Op))
|
||||
if _, ok := supportedAccessPolicyOps[op]; !ok {
|
||||
return fmt.Errorf("condition[%d].op is unsupported: %s", index, op)
|
||||
}
|
||||
if op == "in" || op == "not_in" {
|
||||
if _, ok := condition.Value.([]any); !ok {
|
||||
return fmt.Errorf("condition[%d].value must be an array for op %s", index, op)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for index := range policy.Groups {
|
||||
if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil {
|
||||
return fmt.Errorf("group[%d]: %w", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
63
model/log.go
63
model/log.go
@@ -199,6 +199,49 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
|
||||
}
|
||||
}
|
||||
|
||||
type RecordTaskBillingLogParams struct {
|
||||
UserId int
|
||||
LogType int
|
||||
Content string
|
||||
ChannelId int
|
||||
ModelName string
|
||||
Quota int
|
||||
TokenId int
|
||||
Group string
|
||||
Other map[string]interface{}
|
||||
}
|
||||
|
||||
func RecordTaskBillingLog(params RecordTaskBillingLogParams) {
|
||||
if params.LogType == LogTypeConsume && !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username, _ := GetUsernameById(params.UserId, false)
|
||||
tokenName := ""
|
||||
if params.TokenId > 0 {
|
||||
if token, err := GetTokenById(params.TokenId); err == nil {
|
||||
tokenName = token.Name
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: params.UserId,
|
||||
Username: username,
|
||||
CreatedAt: common.GetTimestamp(),
|
||||
Type: params.LogType,
|
||||
Content: params.Content,
|
||||
TokenName: tokenName,
|
||||
ModelName: params.ModelName,
|
||||
Quota: params.Quota,
|
||||
ChannelId: params.ChannelId,
|
||||
TokenId: params.TokenId,
|
||||
Group: params.Group,
|
||||
Other: common.MapToJsonStr(params.Other),
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
common.SysLog("failed to record task billing log: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
@@ -252,8 +295,24 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
Id int `gorm:"column:id"`
|
||||
Name string `gorm:"column:name"`
|
||||
}
|
||||
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
|
||||
return logs, total, err
|
||||
if common.MemoryCacheEnabled {
|
||||
// Cache get channel
|
||||
for _, channelId := range channelIds.Items() {
|
||||
if cacheChannel, err := CacheGetChannel(channelId); err == nil {
|
||||
channels = append(channels, struct {
|
||||
Id int `gorm:"column:id"`
|
||||
Name string `gorm:"column:name"`
|
||||
}{
|
||||
Id: channelId,
|
||||
Name: cacheChannel.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Bulk query channels from DB
|
||||
if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
|
||||
return logs, total, err
|
||||
}
|
||||
}
|
||||
channelMap := make(map[int]string, len(channels))
|
||||
for _, channel := range channels {
|
||||
|
||||
@@ -157,6 +157,19 @@ func (midjourney *Midjourney) Update() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Returns (true, nil) if this caller won the update, (false, nil) if
|
||||
// another process already moved the task out of fromStatus.
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback.
|
||||
func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
|
||||
result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func MjBulkUpdate(mjIds []string, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("mj_id in (?)", mjIds).
|
||||
|
||||
182
model/task.go
182
model/task.go
@@ -1,10 +1,12 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
commonRelay "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -64,13 +66,12 @@ type Task struct {
|
||||
}
|
||||
|
||||
func (t *Task) SetData(data any) {
|
||||
b, _ := json.Marshal(data)
|
||||
b, _ := common.Marshal(data)
|
||||
t.Data = json.RawMessage(b)
|
||||
}
|
||||
|
||||
func (t *Task) GetData(v any) error {
|
||||
err := json.Unmarshal(t.Data, &v)
|
||||
return err
|
||||
return common.Unmarshal(t.Data, &v)
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
@@ -85,18 +86,59 @@ func (m *Properties) Scan(val interface{}) error {
|
||||
*m = Properties{}
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(bytesValue, m)
|
||||
return common.Unmarshal(bytesValue, m)
|
||||
}
|
||||
|
||||
func (m Properties) Value() (driver.Value, error) {
|
||||
if m == (Properties{}) {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(m)
|
||||
return common.Marshal(m)
|
||||
}
|
||||
|
||||
type TaskPrivateData struct {
|
||||
Key string `json:"key,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
|
||||
ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等)
|
||||
// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
|
||||
BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription"
|
||||
SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
|
||||
TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款
|
||||
BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算)
|
||||
}
|
||||
|
||||
// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
|
||||
type TaskBillingContext struct {
|
||||
ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
|
||||
GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
|
||||
ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
|
||||
OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
|
||||
OriginModelName string `json:"origin_model_name,omitempty"` // 模型名称,必须为OriginModelName
|
||||
PerCallBilling bool `json:"per_call_billing,omitempty"` // 按次计费:跳过轮询阶段的差额结算
|
||||
}
|
||||
|
||||
// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
|
||||
// 旧数据没有 UpstreamTaskID 时,TaskID 本身就是上游 ID
|
||||
func (t *Task) GetUpstreamTaskID() string {
|
||||
if t.PrivateData.UpstreamTaskID != "" {
|
||||
return t.PrivateData.UpstreamTaskID
|
||||
}
|
||||
return t.TaskID
|
||||
}
|
||||
|
||||
// GetResultURL 获取任务结果 URL(视频地址等)
|
||||
// 新数据存在 PrivateData.ResultURL 中;旧数据回退到 FailReason(历史兼容)
|
||||
func (t *Task) GetResultURL() string {
|
||||
if t.PrivateData.ResultURL != "" {
|
||||
return t.PrivateData.ResultURL
|
||||
}
|
||||
return t.FailReason
|
||||
}
|
||||
|
||||
// GenerateTaskID 生成对外暴露的 task_xxxx 格式 ID
|
||||
func GenerateTaskID() string {
|
||||
key, _ := common.GenerateRandomCharsKey(32)
|
||||
return "task_" + key
|
||||
}
|
||||
|
||||
func (p *TaskPrivateData) Scan(val interface{}) error {
|
||||
@@ -104,14 +146,14 @@ func (p *TaskPrivateData) Scan(val interface{}) error {
|
||||
if len(bytesValue) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(bytesValue, p)
|
||||
return common.Unmarshal(bytesValue, p)
|
||||
}
|
||||
|
||||
func (p TaskPrivateData) Value() (driver.Value, error) {
|
||||
if (p == TaskPrivateData{}) {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(p)
|
||||
return common.Marshal(p)
|
||||
}
|
||||
|
||||
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||
@@ -142,7 +184,16 @@ func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo)
|
||||
}
|
||||
}
|
||||
|
||||
// 使用预生成的公开 ID(如果有),否则新生成
|
||||
taskID := ""
|
||||
if relayInfo.TaskRelayInfo != nil && relayInfo.TaskRelayInfo.PublicTaskID != "" {
|
||||
taskID = relayInfo.TaskRelayInfo.PublicTaskID
|
||||
} else {
|
||||
taskID = GenerateTaskID()
|
||||
}
|
||||
|
||||
t := &Task{
|
||||
TaskID: taskID,
|
||||
UserId: relayInfo.UserId,
|
||||
Group: relayInfo.UsingGroup,
|
||||
SubmitTime: time.Now().Unix(),
|
||||
@@ -237,6 +288,20 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
|
||||
return tasks
|
||||
}
|
||||
|
||||
func GetTimedOutUnfinishedTasks(cutoffUnix int64, limit int) []*Task {
|
||||
var tasks []*Task
|
||||
err := DB.Where("progress != ?", "100%").
|
||||
Where("status NOT IN ?", []string{TaskStatusFailure, TaskStatusSuccess}).
|
||||
Where("submit_time < ?", cutoffUnix).
|
||||
Order("submit_time").
|
||||
Limit(limit).
|
||||
Find(&tasks).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
func GetAllUnFinishSyncTasks(limit int) []*Task {
|
||||
var tasks []*Task
|
||||
var err error
|
||||
@@ -291,40 +356,70 @@ func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func TaskUpdateProgress(id int64, progress string) error {
|
||||
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
|
||||
}
|
||||
|
||||
func (Task *Task) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(Task).Error
|
||||
return err
|
||||
}
|
||||
|
||||
type taskSnapshot struct {
|
||||
Status TaskStatus
|
||||
Progress string
|
||||
StartTime int64
|
||||
FinishTime int64
|
||||
FailReason string
|
||||
ResultURL string
|
||||
Data json.RawMessage
|
||||
}
|
||||
|
||||
func (s taskSnapshot) Equal(other taskSnapshot) bool {
|
||||
return s.Status == other.Status &&
|
||||
s.Progress == other.Progress &&
|
||||
s.StartTime == other.StartTime &&
|
||||
s.FinishTime == other.FinishTime &&
|
||||
s.FailReason == other.FailReason &&
|
||||
s.ResultURL == other.ResultURL &&
|
||||
bytes.Equal(s.Data, other.Data)
|
||||
}
|
||||
|
||||
func (t *Task) Snapshot() taskSnapshot {
|
||||
return taskSnapshot{
|
||||
Status: t.Status,
|
||||
Progress: t.Progress,
|
||||
StartTime: t.StartTime,
|
||||
FinishTime: t.FinishTime,
|
||||
FailReason: t.FailReason,
|
||||
ResultURL: t.PrivateData.ResultURL,
|
||||
Data: t.Data,
|
||||
}
|
||||
}
|
||||
|
||||
func (Task *Task) Update() error {
|
||||
var err error
|
||||
err = DB.Save(Task).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
|
||||
if len(TaskIds) == 0 {
|
||||
return nil
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Returns (true, nil) if this caller won the update, (false, nil) if
|
||||
// another process already moved the task out of fromStatus.
|
||||
//
|
||||
// Uses Model().Select("*").Updates() instead of Save() because GORM's Save
|
||||
// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches
|
||||
// zero rows, which silently bypasses the CAS guard.
|
||||
func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
|
||||
result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return DB.Model(&Task{}).
|
||||
Where("task_id in (?)", TaskIds).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
|
||||
if len(taskIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return DB.Model(&Task{}).
|
||||
Where("id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// TaskBulkUpdateByID performs an unconditional bulk UPDATE by primary key IDs.
|
||||
// WARNING: This function has NO CAS (Compare-And-Swap) guard — it will overwrite
|
||||
// any concurrent status changes. DO NOT use in billing/quota lifecycle flows
|
||||
// (e.g., timeout, success, failure transitions that trigger refunds or settlements).
|
||||
// For status transitions that involve billing, use Task.UpdateWithStatus() instead.
|
||||
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
@@ -339,37 +434,6 @@ type TaskQuotaUsage struct {
|
||||
Count float64 `json:"count"`
|
||||
}
|
||||
|
||||
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
|
||||
query := DB.Model(Task{})
|
||||
// 添加过滤条件
|
||||
if queryParams.ChannelID != "" {
|
||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||
}
|
||||
if queryParams.UserID != "" {
|
||||
query = query.Where("user_id = ?", queryParams.UserID)
|
||||
}
|
||||
if len(queryParams.UserIDs) != 0 {
|
||||
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
||||
}
|
||||
if queryParams.TaskID != "" {
|
||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||
}
|
||||
if queryParams.Action != "" {
|
||||
query = query.Where("action = ?", queryParams.Action)
|
||||
}
|
||||
if queryParams.Status != "" {
|
||||
query = query.Where("status = ?", queryParams.Status)
|
||||
}
|
||||
if queryParams.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
|
||||
return stat, err
|
||||
}
|
||||
|
||||
// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
|
||||
func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
|
||||
var total int64
|
||||
@@ -438,6 +502,6 @@ func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo {
|
||||
openAIVideo.SetProgressStr(t.Progress)
|
||||
openAIVideo.CreatedAt = t.CreatedAt
|
||||
openAIVideo.CompletedAt = t.UpdatedAt
|
||||
openAIVideo.SetMetadata("url", t.FailReason)
|
||||
openAIVideo.SetMetadata("url", t.GetResultURL())
|
||||
return openAIVideo
|
||||
}
|
||||
|
||||
217
model/task_cas_test.go
Normal file
217
model/task_cas_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
panic("failed to open test db: " + err.Error())
|
||||
}
|
||||
DB = db
|
||||
LOG_DB = db
|
||||
|
||||
common.UsingSQLite = true
|
||||
common.RedisEnabled = false
|
||||
common.BatchUpdateEnabled = false
|
||||
common.LogConsumeEnabled = true
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
panic("failed to get sql.DB: " + err.Error())
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
|
||||
panic("failed to migrate: " + err.Error())
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func truncateTables(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Cleanup(func() {
|
||||
DB.Exec("DELETE FROM tasks")
|
||||
DB.Exec("DELETE FROM users")
|
||||
DB.Exec("DELETE FROM tokens")
|
||||
DB.Exec("DELETE FROM logs")
|
||||
DB.Exec("DELETE FROM channels")
|
||||
})
|
||||
}
|
||||
|
||||
func insertTask(t *testing.T, task *Task) {
|
||||
t.Helper()
|
||||
task.CreatedAt = time.Now().Unix()
|
||||
task.UpdatedAt = time.Now().Unix()
|
||||
require.NoError(t, DB.Create(task).Error)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Snapshot / Equal — pure logic tests (no DB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSnapshotEqual_Same(t *testing.T) {
|
||||
s := taskSnapshot{
|
||||
Status: TaskStatusInProgress,
|
||||
Progress: "50%",
|
||||
StartTime: 1000,
|
||||
FinishTime: 0,
|
||||
FailReason: "",
|
||||
ResultURL: "",
|
||||
Data: json.RawMessage(`{"key":"value"}`),
|
||||
}
|
||||
assert.True(t, s.Equal(s))
|
||||
}
|
||||
|
||||
func TestSnapshotEqual_DifferentStatus(t *testing.T) {
|
||||
a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
|
||||
b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestSnapshotEqual_DifferentProgress(t *testing.T) {
|
||||
a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
|
||||
b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestSnapshotEqual_DifferentData(t *testing.T) {
|
||||
a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
|
||||
b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
|
||||
a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
|
||||
b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
|
||||
// bytes.Equal(nil, []byte{}) == true
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestSnapshot_Roundtrip(t *testing.T) {
|
||||
task := &Task{
|
||||
Status: TaskStatusInProgress,
|
||||
Progress: "42%",
|
||||
StartTime: 1234,
|
||||
FinishTime: 5678,
|
||||
FailReason: "timeout",
|
||||
PrivateData: TaskPrivateData{
|
||||
ResultURL: "https://example.com/result.mp4",
|
||||
},
|
||||
Data: json.RawMessage(`{"model":"test-model"}`),
|
||||
}
|
||||
snap := task.Snapshot()
|
||||
assert.Equal(t, task.Status, snap.Status)
|
||||
assert.Equal(t, task.Progress, snap.Progress)
|
||||
assert.Equal(t, task.StartTime, snap.StartTime)
|
||||
assert.Equal(t, task.FinishTime, snap.FinishTime)
|
||||
assert.Equal(t, task.FailReason, snap.FailReason)
|
||||
assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
|
||||
assert.JSONEq(t, string(task.Data), string(snap.Data))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// UpdateWithStatus CAS — DB integration tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestUpdateWithStatus_Win(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
task := &Task{
|
||||
TaskID: "task_cas_win",
|
||||
Status: TaskStatusInProgress,
|
||||
Progress: "50%",
|
||||
Data: json.RawMessage(`{}`),
|
||||
}
|
||||
insertTask(t, task)
|
||||
|
||||
task.Status = TaskStatusSuccess
|
||||
task.Progress = "100%"
|
||||
won, err := task.UpdateWithStatus(TaskStatusInProgress)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, won)
|
||||
|
||||
var reloaded Task
|
||||
require.NoError(t, DB.First(&reloaded, task.ID).Error)
|
||||
assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
|
||||
assert.Equal(t, "100%", reloaded.Progress)
|
||||
}
|
||||
|
||||
func TestUpdateWithStatus_Lose(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
task := &Task{
|
||||
TaskID: "task_cas_lose",
|
||||
Status: TaskStatusFailure,
|
||||
Data: json.RawMessage(`{}`),
|
||||
}
|
||||
insertTask(t, task)
|
||||
|
||||
task.Status = TaskStatusSuccess
|
||||
won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
|
||||
require.NoError(t, err)
|
||||
assert.False(t, won)
|
||||
|
||||
var reloaded Task
|
||||
require.NoError(t, DB.First(&reloaded, task.ID).Error)
|
||||
assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
|
||||
}
|
||||
|
||||
func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
task := &Task{
|
||||
TaskID: "task_cas_race",
|
||||
Status: TaskStatusInProgress,
|
||||
Quota: 1000,
|
||||
Data: json.RawMessage(`{}`),
|
||||
}
|
||||
insertTask(t, task)
|
||||
|
||||
const goroutines = 5
|
||||
wins := make([]bool, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
t := &Task{}
|
||||
*t = Task{
|
||||
ID: task.ID,
|
||||
TaskID: task.TaskID,
|
||||
Status: TaskStatusSuccess,
|
||||
Progress: "100%",
|
||||
Quota: task.Quota,
|
||||
Data: json.RawMessage(`{}`),
|
||||
}
|
||||
t.CreatedAt = task.CreatedAt
|
||||
t.UpdatedAt = time.Now().Unix()
|
||||
won, err := t.UpdateWithStatus(TaskStatusInProgress)
|
||||
if err == nil {
|
||||
wins[idx] = won
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
winCount := 0
|
||||
for _, w := range wins {
|
||||
if w {
|
||||
winCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
|
||||
}
|
||||
@@ -360,7 +360,7 @@ func DeleteTokenById(id int, userId int) (err error) {
|
||||
return token.Delete()
|
||||
}
|
||||
|
||||
func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
@@ -373,10 +373,10 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
})
|
||||
}
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota)
|
||||
return nil
|
||||
}
|
||||
return increaseTokenQuota(id, quota)
|
||||
return increaseTokenQuota(tokenId, quota)
|
||||
}
|
||||
|
||||
func increaseTokenQuota(id int, quota int) (err error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,6 +16,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const UserNameMaxLength = 20
|
||||
|
||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||
// Otherwise, the sensitive information will be saved on local storage in plain text!
|
||||
type User struct {
|
||||
@@ -536,6 +539,37 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
return updateUserCache(*user)
|
||||
}
|
||||
|
||||
func (user *User) ClearBinding(bindingType string) error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("user id is empty")
|
||||
}
|
||||
|
||||
bindingColumnMap := map[string]string{
|
||||
"email": "email",
|
||||
"github": "github_id",
|
||||
"discord": "discord_id",
|
||||
"oidc": "oidc_id",
|
||||
"wechat": "wechat_id",
|
||||
"telegram": "telegram_id",
|
||||
"linuxdo": "linux_do_id",
|
||||
}
|
||||
|
||||
column, ok := bindingColumnMap[bindingType]
|
||||
if !ok {
|
||||
return errors.New("invalid binding type")
|
||||
}
|
||||
|
||||
if err := DB.Model(&User{}).Where("id = ?", user.Id).Update(column, "").Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := DB.Where("id = ?", user.Id).First(user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return updateUserCache(*user)
|
||||
}
|
||||
|
||||
func (user *User) Delete() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
@@ -820,10 +854,17 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
|
||||
// can be nil setting
|
||||
var safeSetting sql.NullString
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&safeSetting).Error
|
||||
if err != nil {
|
||||
return settingMap, err
|
||||
}
|
||||
if safeSetting.Valid {
|
||||
setting = safeSetting.String
|
||||
} else {
|
||||
setting = ""
|
||||
}
|
||||
userBase := &UserBase{
|
||||
Setting: setting,
|
||||
}
|
||||
|
||||
404
oauth/generic.go
404
oauth/generic.go
@@ -3,19 +3,24 @@ package oauth
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
stdjson "encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -31,6 +36,40 @@ type GenericOAuthProvider struct {
|
||||
config *model.CustomOAuthProvider
|
||||
}
|
||||
|
||||
type accessPolicy struct {
|
||||
Logic string `json:"logic"`
|
||||
Conditions []accessCondition `json:"conditions"`
|
||||
Groups []accessPolicy `json:"groups"`
|
||||
}
|
||||
|
||||
type accessCondition struct {
|
||||
Field string `json:"field"`
|
||||
Op string `json:"op"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type accessPolicyFailure struct {
|
||||
Field string
|
||||
Op string
|
||||
Expected any
|
||||
Current any
|
||||
}
|
||||
|
||||
var supportedAccessPolicyOps = []string{
|
||||
"eq",
|
||||
"ne",
|
||||
"gt",
|
||||
"gte",
|
||||
"lt",
|
||||
"lte",
|
||||
"in",
|
||||
"not_in",
|
||||
"contains",
|
||||
"not_contains",
|
||||
"exists",
|
||||
"not_exists",
|
||||
}
|
||||
|
||||
// NewGenericOAuthProvider creates a new generic OAuth provider from config
|
||||
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
|
||||
return &GenericOAuthProvider{config: config}
|
||||
@@ -125,7 +164,7 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c
|
||||
ErrorDesc string `json:"error_description"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &tokenResponse); err != nil {
|
||||
if err := common.Unmarshal(body, &tokenResponse); err != nil {
|
||||
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
|
||||
parsedValues, parseErr := url.ParseQuery(bodyStr)
|
||||
if parseErr != nil {
|
||||
@@ -227,11 +266,30 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
|
||||
p.config.Slug, userId, username, displayName, email)
|
||||
|
||||
policyRaw := strings.TrimSpace(p.config.AccessPolicy)
|
||||
if policyRaw != "" {
|
||||
policy, err := parseAccessPolicy(policyRaw)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration")
|
||||
}
|
||||
allowed, failure := evaluateAccessPolicy(bodyStr, policy)
|
||||
if !allowed {
|
||||
message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure)
|
||||
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v",
|
||||
p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current))
|
||||
return nil, &AccessDeniedError{Message: message}
|
||||
}
|
||||
}
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: userId,
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
Email: email,
|
||||
Extra: map[string]any{
|
||||
"provider": p.config.Slug,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -266,3 +324,345 @@ func (p *GenericOAuthProvider) GetProviderId() int {
|
||||
func (p *GenericOAuthProvider) IsGenericProvider() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func parseAccessPolicy(raw string) (*accessPolicy, error) {
|
||||
var policy accessPolicy
|
||||
if err := common.UnmarshalJsonStr(raw, &policy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateAccessPolicy(&policy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &policy, nil
|
||||
}
|
||||
|
||||
func validateAccessPolicy(policy *accessPolicy) error {
|
||||
if policy == nil {
|
||||
return errors.New("policy is nil")
|
||||
}
|
||||
|
||||
logic := strings.ToLower(strings.TrimSpace(policy.Logic))
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
if !lo.Contains([]string{"and", "or"}, logic) {
|
||||
return fmt.Errorf("unsupported policy logic: %s", logic)
|
||||
}
|
||||
policy.Logic = logic
|
||||
|
||||
if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
|
||||
return errors.New("policy requires at least one condition or group")
|
||||
}
|
||||
|
||||
for index := range policy.Conditions {
|
||||
if err := validateAccessCondition(&policy.Conditions[index], index); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for index := range policy.Groups {
|
||||
if err := validateAccessPolicy(&policy.Groups[index]); err != nil {
|
||||
return fmt.Errorf("invalid policy group[%d]: %w", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateAccessCondition(condition *accessCondition, index int) error {
|
||||
if condition == nil {
|
||||
return fmt.Errorf("condition[%d] is nil", index)
|
||||
}
|
||||
|
||||
condition.Field = strings.TrimSpace(condition.Field)
|
||||
if condition.Field == "" {
|
||||
return fmt.Errorf("condition[%d].field is required", index)
|
||||
}
|
||||
|
||||
condition.Op = normalizePolicyOp(condition.Op)
|
||||
if !lo.Contains(supportedAccessPolicyOps, condition.Op) {
|
||||
return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op)
|
||||
}
|
||||
|
||||
if lo.Contains([]string{"in", "not_in"}, condition.Op) {
|
||||
if _, ok := condition.Value.([]any); !ok {
|
||||
return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) {
|
||||
if policy == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logic := strings.ToLower(strings.TrimSpace(policy.Logic))
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
|
||||
hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0
|
||||
if !hasAny {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if logic == "or" {
|
||||
var firstFailure *accessPolicyFailure
|
||||
for _, cond := range policy.Conditions {
|
||||
ok, failure := evaluateAccessCondition(body, cond)
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
if firstFailure == nil {
|
||||
firstFailure = failure
|
||||
}
|
||||
}
|
||||
for _, group := range policy.Groups {
|
||||
ok, failure := evaluateAccessPolicy(body, &group)
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
if firstFailure == nil {
|
||||
firstFailure = failure
|
||||
}
|
||||
}
|
||||
return false, firstFailure
|
||||
}
|
||||
|
||||
for _, cond := range policy.Conditions {
|
||||
ok, failure := evaluateAccessCondition(body, cond)
|
||||
if !ok {
|
||||
return false, failure
|
||||
}
|
||||
}
|
||||
for _, group := range policy.Groups {
|
||||
ok, failure := evaluateAccessPolicy(body, &group)
|
||||
if !ok {
|
||||
return false, failure
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) {
|
||||
path := cond.Field
|
||||
op := cond.Op
|
||||
result := gjson.Get(body, path)
|
||||
current := gjsonResultToValue(result)
|
||||
failure := &accessPolicyFailure{
|
||||
Field: path,
|
||||
Op: op,
|
||||
Expected: cond.Value,
|
||||
Current: current,
|
||||
}
|
||||
|
||||
switch op {
|
||||
case "exists":
|
||||
return result.Exists(), failure
|
||||
case "not_exists":
|
||||
return !result.Exists(), failure
|
||||
case "eq":
|
||||
return compareAny(current, cond.Value) == 0, failure
|
||||
case "ne":
|
||||
return compareAny(current, cond.Value) != 0, failure
|
||||
case "gt":
|
||||
return compareAny(current, cond.Value) > 0, failure
|
||||
case "gte":
|
||||
return compareAny(current, cond.Value) >= 0, failure
|
||||
case "lt":
|
||||
return compareAny(current, cond.Value) < 0, failure
|
||||
case "lte":
|
||||
return compareAny(current, cond.Value) <= 0, failure
|
||||
case "in":
|
||||
return valueInSlice(current, cond.Value), failure
|
||||
case "not_in":
|
||||
return !valueInSlice(current, cond.Value), failure
|
||||
case "contains":
|
||||
return containsValue(current, cond.Value), failure
|
||||
case "not_contains":
|
||||
return !containsValue(current, cond.Value), failure
|
||||
default:
|
||||
return false, failure
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePolicyOp(op string) string {
|
||||
return strings.ToLower(strings.TrimSpace(op))
|
||||
}
|
||||
|
||||
func gjsonResultToValue(result gjson.Result) any {
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
}
|
||||
if result.IsArray() {
|
||||
arr := result.Array()
|
||||
values := make([]any, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
values = append(values, gjsonResultToValue(item))
|
||||
}
|
||||
return values
|
||||
}
|
||||
switch result.Type {
|
||||
case gjson.Null:
|
||||
return nil
|
||||
case gjson.True:
|
||||
return true
|
||||
case gjson.False:
|
||||
return false
|
||||
case gjson.Number:
|
||||
return result.Num
|
||||
case gjson.String:
|
||||
return result.String()
|
||||
case gjson.JSON:
|
||||
var data any
|
||||
if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil {
|
||||
return data
|
||||
}
|
||||
return result.Raw
|
||||
default:
|
||||
return result.Value()
|
||||
}
|
||||
}
|
||||
|
||||
func compareAny(left any, right any) int {
|
||||
if lf, ok := toFloat(left); ok {
|
||||
if rf, ok2 := toFloat(right); ok2 {
|
||||
switch {
|
||||
case lf < rf:
|
||||
return -1
|
||||
case lf > rf:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ls := strings.TrimSpace(fmt.Sprint(left))
|
||||
rs := strings.TrimSpace(fmt.Sprint(right))
|
||||
switch {
|
||||
case ls < rs:
|
||||
return -1
|
||||
case ls > rs:
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func toFloat(v any) (float64, bool) {
|
||||
switch value := v.(type) {
|
||||
case float64:
|
||||
return value, true
|
||||
case float32:
|
||||
return float64(value), true
|
||||
case int:
|
||||
return float64(value), true
|
||||
case int8:
|
||||
return float64(value), true
|
||||
case int16:
|
||||
return float64(value), true
|
||||
case int32:
|
||||
return float64(value), true
|
||||
case int64:
|
||||
return float64(value), true
|
||||
case uint:
|
||||
return float64(value), true
|
||||
case uint8:
|
||||
return float64(value), true
|
||||
case uint16:
|
||||
return float64(value), true
|
||||
case uint32:
|
||||
return float64(value), true
|
||||
case uint64:
|
||||
return float64(value), true
|
||||
case stdjson.Number:
|
||||
n, err := value.Float64()
|
||||
if err == nil {
|
||||
return n, true
|
||||
}
|
||||
case string:
|
||||
n, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
|
||||
if err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func valueInSlice(current any, expected any) bool {
|
||||
list, ok := expected.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return lo.ContainsBy(list, func(item any) bool {
|
||||
return compareAny(current, item) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func containsValue(current any, expected any) bool {
|
||||
switch value := current.(type) {
|
||||
case string:
|
||||
target := strings.TrimSpace(fmt.Sprint(expected))
|
||||
return strings.Contains(value, target)
|
||||
case []any:
|
||||
return lo.ContainsBy(value, func(item any) bool {
|
||||
return compareAny(item, expected) == 0
|
||||
})
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string {
|
||||
defaultMessage := "Access denied: your account does not meet this provider's access requirements."
|
||||
message := strings.TrimSpace(template)
|
||||
if message == "" {
|
||||
return defaultMessage
|
||||
}
|
||||
|
||||
if failure == nil {
|
||||
failure = &accessPolicyFailure{}
|
||||
}
|
||||
|
||||
replacements := map[string]string{
|
||||
"{{provider}}": providerName,
|
||||
"{{field}}": failure.Field,
|
||||
"{{op}}": failure.Op,
|
||||
"{{required}}": fmt.Sprint(failure.Expected),
|
||||
"{{current}}": fmt.Sprint(failure.Current),
|
||||
}
|
||||
|
||||
for key, value := range replacements {
|
||||
message = strings.ReplaceAll(message, key, value)
|
||||
}
|
||||
|
||||
currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`)
|
||||
message = currentPattern.ReplaceAllStringFunc(message, func(token string) string {
|
||||
match := currentPattern.FindStringSubmatch(token)
|
||||
if len(match) != 2 {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSpace(match[1])
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(gjson.Get(body, path).String())
|
||||
})
|
||||
|
||||
requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`)
|
||||
message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string {
|
||||
match := requiredPattern.FindStringSubmatch(token)
|
||||
if len(match) != 2 {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSpace(match[1])
|
||||
if failure.Field == path {
|
||||
return fmt.Sprint(failure.Expected)
|
||||
}
|
||||
return ""
|
||||
})
|
||||
|
||||
return strings.TrimSpace(message)
|
||||
}
|
||||
|
||||
@@ -57,3 +57,12 @@ func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string)
|
||||
RawError: rawError,
|
||||
}
|
||||
}
|
||||
|
||||
// AccessDeniedError is a direct user-facing access denial message.
|
||||
type AccessDeniedError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AccessDeniedError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
@@ -36,6 +36,32 @@ type TaskAdaptor interface {
|
||||
|
||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
||||
|
||||
// ── Billing ──────────────────────────────────────────────────────
|
||||
|
||||
// EstimateBilling returns OtherRatios for pre-charge based on user request.
|
||||
// Called after ValidateRequestAndSetAction, before price calculation.
|
||||
// Adaptors should extract duration, resolution, etc. from the parsed request
|
||||
// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
|
||||
// Return nil to use the base model price without extra ratios.
|
||||
EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
|
||||
|
||||
// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
|
||||
// submit response. Called after a successful DoResponse.
|
||||
// If the upstream returned actual parameters that differ from the estimate
|
||||
// (e.g. actual seconds), return updated ratios so the caller can recalculate
|
||||
// the quota and settle the delta with the pre-charge.
|
||||
// Return nil if no adjustment is needed.
|
||||
AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
|
||||
|
||||
// AdjustBillingOnComplete returns the actual quota when a task reaches a
|
||||
// terminal state (success/failure) during polling.
|
||||
// Called by the polling loop after ParseTaskResult.
|
||||
// Return a positive value to trigger delta settlement (supplement / refund).
|
||||
// Return 0 to keep the pre-charged amount unchanged.
|
||||
AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
|
||||
|
||||
// ── Request / Response ───────────────────────────────────────────
|
||||
|
||||
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
||||
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
// ── Polling ──────────────────────────────────────────────────────
|
||||
|
||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -61,8 +61,9 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
|
||||
"cookie": {},
|
||||
|
||||
// Additional headers that should not be forwarded by name-matching passthrough rules.
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
"accept-encoding": {},
|
||||
|
||||
// Do not passthrough credentials by wildcard/regex.
|
||||
"authorization": {},
|
||||
|
||||
@@ -110,3 +110,30 @@ func TestProcessHeaderOverride_RuntimeOverrideHasPriority(t *testing.T) {
|
||||
_, ok := headers["X-Legacy"]
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
ctx.Request.Header.Set("X-Trace-Id", "trace-123")
|
||||
ctx.Request.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
IsChannelTest: false,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
HeadersOverride: map[string]any{
|
||||
"*": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
headers, err := processHeaderOverride(info, ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "trace-123", headers["X-Trace-Id"])
|
||||
|
||||
_, hasAcceptEncoding := headers["Accept-Encoding"]
|
||||
require.False(t, hasAcceptEncoding)
|
||||
}
|
||||
|
||||
@@ -42,22 +42,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
|
||||
@@ -1032,6 +1032,46 @@ func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func buildUsageFromGeminiMetadata(metadata dto.GeminiUsageMetadata, fallbackPromptTokens int) dto.Usage {
|
||||
promptTokens := metadata.PromptTokenCount + metadata.ToolUsePromptTokenCount
|
||||
if promptTokens <= 0 && fallbackPromptTokens > 0 {
|
||||
promptTokens = fallbackPromptTokens
|
||||
}
|
||||
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: metadata.CandidatesTokenCount + metadata.ThoughtsTokenCount,
|
||||
TotalTokens: metadata.TotalTokenCount,
|
||||
}
|
||||
usage.CompletionTokenDetails.ReasoningTokens = metadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = metadata.CachedContentTokenCount
|
||||
|
||||
for _, detail := range metadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens += detail.TokenCount
|
||||
}
|
||||
}
|
||||
for _, detail := range metadata.ToolUsePromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens += detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens += detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if usage.TotalTokens > 0 && usage.CompletionTokens <= 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.PromptTokens > 0 && usage.PromptTokensDetails.TextTokens == 0 && usage.PromptTokensDetails.AudioTokens == 0 {
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
}
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
@@ -1272,18 +1312,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
mappedUsage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
*usage = mappedUsage
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
@@ -1295,11 +1325,6 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
if usage.TotalTokens > 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.CompletionTokens <= 0 {
|
||||
if info.ReceivedResponseCount > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
@@ -1416,21 +1441,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
}
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
if usage.PromptTokens <= 0 {
|
||||
usage.PromptTokens = info.GetEstimatePromptTokens()
|
||||
}
|
||||
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
|
||||
var newAPIError *types.NewAPIError
|
||||
if geminiResponse.PromptFeedback != nil && geminiResponse.PromptFeedback.BlockReason != nil {
|
||||
@@ -1466,23 +1477,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.PromptTokensDetails.CachedTokens = geminiResponse.UsageMetadata.CachedContentTokenCount
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
usage := buildUsageFromGeminiMetadata(geminiResponse.UsageMetadata, info.GetEstimatePromptTokens())
|
||||
|
||||
fullTextResponse.Usage = usage
|
||||
|
||||
|
||||
333
relay/channel/gemini/relay_gemini_usage_test.go
Normal file
333
relay/channel/gemini/relay_gemini_usage_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiChatHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 151,
|
||||
ToolUsePromptTokenCount: 18329,
|
||||
CandidatesTokenCount: 1089,
|
||||
ThoughtsTokenCount: 1120,
|
||||
TotalTokenCount: 20689,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiChatHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 18480, usage.PromptTokens)
|
||||
require.Equal(t, 2209, usage.CompletionTokens)
|
||||
require.Equal(t, 20689, usage.TotalTokens)
|
||||
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
func TestGeminiStreamHandlerCompletionTokensExcludeToolUsePromptTokens(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldStreamingTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 300
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldStreamingTimeout
|
||||
})
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
|
||||
chunk := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "partial"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 151,
|
||||
ToolUsePromptTokenCount: 18329,
|
||||
CandidatesTokenCount: 1089,
|
||||
ThoughtsTokenCount: 1120,
|
||||
TotalTokenCount: 20689,
|
||||
},
|
||||
}
|
||||
|
||||
chunkData, err := common.Marshal(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
||||
return true
|
||||
})
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 18480, usage.PromptTokens)
|
||||
require.Equal(t, 2209, usage.CompletionTokens)
|
||||
require.Equal(t, 20689, usage.TotalTokens)
|
||||
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
func TestGeminiTextGenerationHandlerPromptTokensIncludeToolUsePromptTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 151,
|
||||
ToolUsePromptTokenCount: 18329,
|
||||
CandidatesTokenCount: 1089,
|
||||
ThoughtsTokenCount: 1120,
|
||||
TotalTokenCount: 20689,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 18480, usage.PromptTokens)
|
||||
require.Equal(t, 2209, usage.CompletionTokens)
|
||||
require.Equal(t, 20689, usage.TotalTokens)
|
||||
require.Equal(t, 1120, usage.CompletionTokenDetails.ReasoningTokens)
|
||||
}
|
||||
|
||||
func TestGeminiChatHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
info.SetEstimatePromptTokens(20)
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 0,
|
||||
ToolUsePromptTokenCount: 0,
|
||||
CandidatesTokenCount: 90,
|
||||
ThoughtsTokenCount: 10,
|
||||
TotalTokenCount: 110,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiChatHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 20, usage.PromptTokens)
|
||||
require.Equal(t, 100, usage.CompletionTokens)
|
||||
require.Equal(t, 110, usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestGeminiStreamHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldStreamingTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 300
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldStreamingTimeout
|
||||
})
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
info.SetEstimatePromptTokens(20)
|
||||
|
||||
chunk := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "partial"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 0,
|
||||
ToolUsePromptTokenCount: 0,
|
||||
CandidatesTokenCount: 90,
|
||||
ThoughtsTokenCount: 10,
|
||||
TotalTokenCount: 110,
|
||||
},
|
||||
}
|
||||
|
||||
chunkData, err := common.Marshal(chunk)
|
||||
require.NoError(t, err)
|
||||
|
||||
streamBody := []byte("data: " + string(chunkData) + "\n" + "data: [DONE]\n")
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
usage, newAPIError := geminiStreamHandler(c, info, resp, func(_ string, _ *dto.GeminiChatResponse) bool {
|
||||
return true
|
||||
})
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 20, usage.PromptTokens)
|
||||
require.Equal(t, 100, usage.CompletionTokens)
|
||||
require.Equal(t, 110, usage.TotalTokens)
|
||||
}
|
||||
|
||||
func TestGeminiTextGenerationHandlerUsesEstimatedPromptTokensWhenUsagePromptMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-3-flash-preview:generateContent", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
OriginModelName: "gemini-3-flash-preview",
|
||||
ChannelMeta: &relaycommon.ChannelMeta{
|
||||
UpstreamModelName: "gemini-3-flash-preview",
|
||||
},
|
||||
}
|
||||
info.SetEstimatePromptTokens(20)
|
||||
|
||||
payload := dto.GeminiChatResponse{
|
||||
Candidates: []dto.GeminiChatCandidate{
|
||||
{
|
||||
Content: dto.GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: "ok"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
UsageMetadata: dto.GeminiUsageMetadata{
|
||||
PromptTokenCount: 0,
|
||||
ToolUsePromptTokenCount: 0,
|
||||
CandidatesTokenCount: 90,
|
||||
ThoughtsTokenCount: 10,
|
||||
TotalTokenCount: 110,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := common.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
|
||||
usage, newAPIError := GeminiTextGenerationHandler(c, info, resp)
|
||||
require.Nil(t, newAPIError)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 20, usage.PromptTokens)
|
||||
require.Equal(t, 100, usage.CompletionTokens)
|
||||
require.Equal(t, 110, usage.TotalTokens)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
@@ -26,7 +27,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -119,8 +121,14 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return handleTTSResponse(c, resp, info)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
adaptor := claude.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
channelconstant "github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
)
|
||||
|
||||
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
@@ -13,13 +14,17 @@ func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeMiniMax]
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
return fmt.Sprintf("%s/v1/t2a_v2", baseUrl), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/samber/lo"
|
||||
@@ -108,10 +109,10 @@ type AliMetadata struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
aliReq *AliVideoRequest
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// 阿里通义万相支持 JSON 格式,不使用 multipart
|
||||
var taskReq relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
||||
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
|
||||
}
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
a.aliReq = aliReq
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
bodyBytes, err := common.Marshal(a.aliReq)
|
||||
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_task_request_failed")
|
||||
}
|
||||
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert_to_ali_request_failed")
|
||||
}
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
|
||||
bodyBytes, err := common.Marshal(aliReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
||||
}
|
||||
|
||||
return bytes.NewReader(bodyBytes), nil
|
||||
}
|
||||
|
||||
@@ -252,8 +253,12 @@ func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relaycommon.TaskSubmitReq) (*AliVideoRequest, error) {
|
||||
upstreamModel := req.Model
|
||||
if info.IsModelMapped {
|
||||
upstreamModel = info.UpstreamModelName
|
||||
}
|
||||
aliReq := &AliVideoRequest{
|
||||
Model: req.Model,
|
||||
Model: upstreamModel,
|
||||
Input: AliVideoInput{
|
||||
Prompt: req.Prompt,
|
||||
ImgURL: req.InputReference,
|
||||
@@ -331,23 +336,37 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
}
|
||||
}
|
||||
|
||||
if aliReq.Model != req.Model {
|
||||
if aliReq.Model != upstreamModel {
|
||||
return nil, errors.New("can't change model with metadata")
|
||||
}
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
return aliReq, nil
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
|
||||
// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
otherRatios := map[string]float64{
|
||||
"seconds": float64(aliReq.Parameters.Duration),
|
||||
}
|
||||
|
||||
ratios, err := ProcessAliOtherRatios(aliReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return otherRatios
|
||||
}
|
||||
for s, f := range ratios {
|
||||
info.PriceData.OtherRatios[s] = f
|
||||
for k, v := range ratios {
|
||||
otherRatios[k] = v
|
||||
}
|
||||
|
||||
return aliReq, nil
|
||||
return otherRatios
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper
|
||||
@@ -384,7 +403,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// 转换为 OpenAI 格式响应
|
||||
openAIResp := dto.NewOpenAIVideo()
|
||||
openAIResp.ID = aliResp.Output.TaskID
|
||||
openAIResp.ID = info.PublicTaskID
|
||||
openAIResp.TaskID = info.PublicTaskID
|
||||
openAIResp.Model = c.GetString("model")
|
||||
if openAIResp.Model == "" && info != nil {
|
||||
openAIResp.Model = info.OriginModelName
|
||||
|
||||
@@ -2,7 +2,6 @@ package doubao
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -89,6 +89,7 @@ type responseTask struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -130,8 +131,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
info.UpstreamModelName = body.Model
|
||||
data, err := json.Marshal(body)
|
||||
if info.IsModelMapped {
|
||||
body.Model = info.UpstreamModelName
|
||||
} else {
|
||||
info.UpstreamModelName = body.Model
|
||||
}
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -154,7 +159,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// Parse Doubao response
|
||||
var dResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &dResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &dResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -165,8 +170,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = dResp.ID
|
||||
ov.TaskID = dResp.ID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
|
||||
@@ -234,12 +239,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var dResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &dResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &dResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal doubao task data failed")
|
||||
}
|
||||
|
||||
@@ -307,6 +307,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,10 +14,10 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
@@ -87,6 +85,7 @@ type operationResponse struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -106,7 +105,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
modelName := info.OriginModelName
|
||||
modelName := info.UpstreamModelName
|
||||
version := model_setting.GetGeminiVersionSetting(modelName)
|
||||
|
||||
return fmt.Sprintf(
|
||||
@@ -145,16 +144,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &body.Parameters)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(metadata, &body.Parameters); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,16 +169,16 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
taskID = encodeLocalTaskID(s.Name)
|
||||
taskID = taskcommon.EncodeLocalTaskID(s.Name)
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = taskID
|
||||
ov.TaskID = taskID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -206,7 +200,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
@@ -232,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
if err := common.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -254,9 +248,8 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
|
||||
taskID := encodeLocalTaskID(op.Name)
|
||||
ti.TaskID = taskID
|
||||
ti.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
|
||||
ti.TaskID = taskcommon.EncodeLocalTaskID(op.Name)
|
||||
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
|
||||
|
||||
// Extract URL from generateVideoResponse if available
|
||||
if len(op.Response.GenerateVideoResponse.GeneratedSamples) > 0 {
|
||||
@@ -269,7 +262,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
upstreamName, err := decodeLocalTaskID(task.TaskID)
|
||||
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
|
||||
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
|
||||
upstreamTaskID := task.GetUpstreamTaskID()
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
|
||||
if err != nil {
|
||||
upstreamName = ""
|
||||
}
|
||||
@@ -297,18 +293,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||
|
||||
func extractModelFromOperationName(name string) string {
|
||||
|
||||
@@ -2,7 +2,6 @@ package hailuo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -18,12 +17,14 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
|
||||
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -60,12 +61,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, fmt.Errorf("invalid request type in context")
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -86,7 +87,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var hResp VideoResponse
|
||||
if err := json.Unmarshal(responseBody, &hResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &hResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -101,8 +102,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = hResp.TaskID
|
||||
ov.TaskID = hResp.TaskID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
|
||||
@@ -141,8 +142,8 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) {
|
||||
modelConfig := GetModelConfig(req.Model)
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*VideoRequest, error) {
|
||||
modelConfig := GetModelConfig(info.UpstreamModelName)
|
||||
duration := DefaultDuration
|
||||
if req.Duration > 0 {
|
||||
duration = req.Duration
|
||||
@@ -153,7 +154,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
videoRequest := &VideoRequest{
|
||||
Model: req.Model,
|
||||
Model: info.UpstreamModelName,
|
||||
Prompt: req.Prompt,
|
||||
Duration: &duration,
|
||||
Resolution: resolution,
|
||||
@@ -182,7 +183,7 @@ func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConf
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := QueryTaskResponse{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
@@ -224,7 +225,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var hailuoResp QueryTaskResponse
|
||||
if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &hailuoResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal hailuo task data failed")
|
||||
}
|
||||
|
||||
@@ -271,7 +272,7 @@ func (a *TaskAdaptor) buildVideoURL(_, fileID string) string {
|
||||
}
|
||||
|
||||
var retrieveResp RetrieveFileResponse
|
||||
if err := json.Unmarshal(responseBody, &retrieveResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &retrieveResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
@@ -77,6 +77,7 @@ const (
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
@@ -164,11 +165,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -191,7 +192,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// Parse Jimeng response
|
||||
var jResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &jResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -202,8 +203,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = jResp.Data.TaskID
|
||||
ov.TaskID = jResp.Data.TaskID
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -225,7 +226,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
payloadBytes, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
||||
}
|
||||
@@ -377,9 +378,9 @@ func hmacSHA256(key []byte, data []byte) []byte {
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: req.Model,
|
||||
ReqKey: info.UpstreamModelName,
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
|
||||
@@ -398,13 +399,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
r.BinaryDataBase64 = req.Images
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
|
||||
@@ -432,7 +427,7 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
taskResult := relaycommon.TaskInfo{}
|
||||
@@ -458,7 +453,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var jimengResp responseTask
|
||||
if err := json.Unmarshal(originTask.Data, &jimengResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &jimengResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal jimeng task data failed")
|
||||
}
|
||||
|
||||
@@ -477,8 +472,7 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
|
||||
@@ -2,7 +2,6 @@ package kling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -21,6 +20,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
@@ -97,6 +97,7 @@ type responsePayload struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -149,14 +150,14 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if body.Image == "" && body.ImageTail == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -180,7 +181,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
var kResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &kResp)
|
||||
err = common.Unmarshal(responseBody, &kResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -190,8 +191,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = kResp.Data.TaskId
|
||||
ov.TaskID = kResp.Data.TaskId
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -247,15 +248,15 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
Mode: taskcommon.DefaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
|
||||
ModelName: info.UpstreamModelName,
|
||||
Model: info.UpstreamModelName,
|
||||
CfgScale: 0.5,
|
||||
StaticMask: "",
|
||||
DynamicMasks: []DynamicMask{},
|
||||
@@ -265,14 +266,9 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
if r.ModelName == "" {
|
||||
r.ModelName = "kling-v1"
|
||||
r.Model = "kling-v1"
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
@@ -291,20 +287,6 @@ func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func defaultString(s, def string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func defaultInt(v int, def int) int {
|
||||
if v == 0 {
|
||||
return def
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// ============================
|
||||
// JWT helpers
|
||||
// ============================
|
||||
@@ -340,7 +322,7 @@ func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
resPayload := responsePayload{}
|
||||
err := json.Unmarshal(respBody, &resPayload)
|
||||
err := common.Unmarshal(respBody, &resPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
@@ -374,7 +356,7 @@ func isNewAPIRelay(apiKey string) bool {
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var klingResp responsePayload
|
||||
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal kling task data failed")
|
||||
}
|
||||
|
||||
@@ -401,6 +383,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
Code: fmt.Sprintf("%d", klingResp.Code),
|
||||
}
|
||||
}
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package sora
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -11,12 +15,13 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ============================
|
||||
@@ -57,6 +62,7 @@ type responseTask struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -69,15 +75,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
var req relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -88,6 +94,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seconds, _ := strconv.Atoi(req.Seconds)
|
||||
if seconds == 0 {
|
||||
seconds = req.Duration
|
||||
}
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
|
||||
size := req.Size
|
||||
if size == "" {
|
||||
size = "720x1280"
|
||||
}
|
||||
|
||||
ratios := map[string]float64{
|
||||
"seconds": float64(seconds),
|
||||
"size": 1,
|
||||
}
|
||||
if size == "1792x1024" || size == "1024x1792" {
|
||||
ratios["size"] = 1.666667
|
||||
}
|
||||
return ratios
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
@@ -107,6 +148,74 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_request_body_failed")
|
||||
}
|
||||
cachedBody, err := storage.Bytes()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read_body_bytes_failed")
|
||||
}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
|
||||
bodyMap["model"] = info.UpstreamModelName
|
||||
if newBody, err := common.Marshal(bodyMap); err == nil {
|
||||
return bytes.NewReader(newBody), nil
|
||||
}
|
||||
}
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
formData, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return bytes.NewReader(cachedBody), nil
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
writer.WriteField("model", info.UpstreamModelName)
|
||||
for key, values := range formData.Value {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
writer.WriteField(key, v)
|
||||
}
|
||||
}
|
||||
for fieldName, fileHeaders := range formData.File {
|
||||
for _, fh := range fileHeaders {
|
||||
f, err := fh.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ct := fh.Header.Get("Content-Type")
|
||||
if ct == "" || ct == "application/octet-stream" {
|
||||
buf512 := make([]byte, 512)
|
||||
n, _ := io.ReadFull(f, buf512)
|
||||
ct = http.DetectContentType(buf512[:n])
|
||||
// Re-open after sniffing so the full content is copied below
|
||||
f.Close()
|
||||
f, err = fh.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fh.Filename))
|
||||
h.Set("Content-Type", ct)
|
||||
part, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
continue
|
||||
}
|
||||
io.Copy(part, f)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
return common.ReaderOnly(storage), nil
|
||||
}
|
||||
|
||||
@@ -116,7 +225,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, req
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
@@ -131,17 +240,20 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco
|
||||
return
|
||||
}
|
||||
|
||||
if dResp.ID == "" {
|
||||
if dResp.TaskID == "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
dResp.ID = dResp.TaskID
|
||||
dResp.TaskID = ""
|
||||
upstreamID := dResp.ID
|
||||
if upstreamID == "" {
|
||||
upstreamID = dResp.TaskID
|
||||
}
|
||||
if upstreamID == "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用公开 task_xxxx ID 返回给客户端
|
||||
dResp.ID = info.PublicTaskID
|
||||
dResp.TaskID = info.PublicTaskID
|
||||
c.JSON(http.StatusOK, dResp)
|
||||
return dResp.ID, responseBody, nil
|
||||
return upstreamID, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
@@ -192,7 +304,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
case "completed":
|
||||
taskResult.Status = model.TaskStatusSuccess
|
||||
taskResult.Url = fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, resTask.ID)
|
||||
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
|
||||
case "failed", "cancelled":
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
if resTask.Error != nil {
|
||||
@@ -210,5 +322,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
return task.Data, nil
|
||||
data := task.Data
|
||||
var err error
|
||||
if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
|
||||
return nil, errors.Wrap(err, "set id failed")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -2,18 +2,16 @@ package suno
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -21,11 +19,16 @@ import (
|
||||
)
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
// ParseTaskResult is not used for Suno tasks.
|
||||
// Suno polling uses a dedicated batch-fetch path (service.UpdateSunoTasks) that
|
||||
// receives dto.TaskResponse[[]dto.SunoDataResponse] from the upstream /fetch API.
|
||||
// This differs from the per-task polling used by video adaptors.
|
||||
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
|
||||
return nil, fmt.Errorf("not implement") // todo implement this method if needed
|
||||
return nil, fmt.Errorf("suno uses batch polling via UpdateSunoTasks, ParseTaskResult is not applicable")
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -47,13 +50,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return
|
||||
}
|
||||
|
||||
if sunoRequest.ContinueClipId != "" {
|
||||
if sunoRequest.TaskID == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
info.OriginTaskID = sunoRequest.TaskID
|
||||
}
|
||||
//if sunoRequest.ContinueClipId != "" {
|
||||
// if sunoRequest.TaskID == "" {
|
||||
// taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
|
||||
// return
|
||||
// }
|
||||
// info.OriginTaskID = sunoRequest.TaskID
|
||||
//}
|
||||
|
||||
info.Action = action
|
||||
c.Set("task_request", sunoRequest)
|
||||
@@ -76,12 +79,9 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
sunoRequest, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("task_request not found in context")
|
||||
}
|
||||
data, err := json.Marshal(sunoRequest)
|
||||
data, err := common.Marshal(sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
var sunoResponse dto.TaskResponse[string]
|
||||
err = json.Unmarshal(responseBody, &sunoResponse)
|
||||
err = common.Unmarshal(responseBody, &sunoResponse)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -109,17 +109,13 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
// 使用公开 task_xxxx ID 替换上游 ID 返回给客户端
|
||||
publicResponse := dto.TaskResponse[string]{
|
||||
Code: sunoResponse.Code,
|
||||
Message: sunoResponse.Message,
|
||||
Data: info.PublicTaskID,
|
||||
}
|
||||
c.JSON(http.StatusOK, publicResponse)
|
||||
|
||||
return sunoResponse.Data, nil, nil
|
||||
}
|
||||
@@ -134,7 +130,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
|
||||
byteBody, err := json.Marshal(body)
|
||||
byteBody, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -144,13 +140,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
defer req.Body.Close()
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 15
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
client, err := service.GetHttpClientWithProxy(proxy)
|
||||
|
||||
95
relay/channel/task/taskcommon/helpers.go
Normal file
95
relay/channel/task/taskcommon/helpers.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package taskcommon
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
|
||||
// This replaces the repeated pattern: json.Marshal(metadata) → json.Unmarshal(bytes, &target).
|
||||
func UnmarshalMetadata(metadata map[string]any, target any) error {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
metaBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
}
|
||||
if err := common.Unmarshal(metaBytes, target); err != nil {
|
||||
return fmt.Errorf("unmarshal metadata failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultString returns val if non-empty, otherwise fallback.
|
||||
func DefaultString(val, fallback string) string {
|
||||
if val == "" {
|
||||
return fallback
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// DefaultInt returns val if non-zero, otherwise fallback.
|
||||
func DefaultInt(val, fallback int) int {
|
||||
if val == 0 {
|
||||
return fallback
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// EncodeLocalTaskID encodes an upstream operation name to a URL-safe base64 string.
|
||||
// Used by Gemini/Vertex to store upstream names as task IDs.
|
||||
func EncodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
// DecodeLocalTaskID decodes a base64-encoded upstream operation name.
|
||||
func DecodeLocalTaskID(id string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// BuildProxyURL constructs the video proxy URL using the public task ID.
|
||||
// e.g., "https://your-server.com/v1/videos/task_xxxx/content"
|
||||
func BuildProxyURL(taskID string) string {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/content", system_setting.ServerAddress, taskID)
|
||||
}
|
||||
|
||||
// Status-to-progress mapping constants for polling updates.
|
||||
const (
|
||||
ProgressSubmitted = "10%"
|
||||
ProgressQueued = "20%"
|
||||
ProgressInProgress = "30%"
|
||||
ProgressComplete = "100%"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
|
||||
// Adaptors that do not need custom billing can embed this struct directly.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type BaseBilling struct{}
|
||||
|
||||
// EstimateBilling returns nil (no extra ratios; use base model price).
|
||||
func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
|
||||
func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
|
||||
func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||
return 0
|
||||
}
|
||||
@@ -2,13 +2,12 @@ package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -62,6 +62,7 @@ type operationResponse struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -82,10 +83,10 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
modelName := info.UpstreamModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return nil
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||
sampleCount := 1
|
||||
v, ok := c.Get("task_request")
|
||||
if ok {
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
if req.Metadata != nil {
|
||||
if sc, exists := req.Metadata["sampleCount"]; exists {
|
||||
if i, ok := sc.(int); ok && i > 0 {
|
||||
sampleCount = i
|
||||
}
|
||||
if f, ok := sc.(float64); ok && int(f) > 0 {
|
||||
sampleCount = int(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]float64{
|
||||
"sampleCount": float64(sampleCount),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
@@ -166,25 +189,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, fmt.Errorf("sampleCount must be greater than 0")
|
||||
}
|
||||
|
||||
// if req.Duration > 0 {
|
||||
// body.Parameters["durationSeconds"] = req.Duration
|
||||
// } else if req.Seconds != "" {
|
||||
// seconds, err := strconv.Atoi(req.Seconds)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "convert seconds to int failed")
|
||||
// }
|
||||
// body.Parameters["durationSeconds"] = seconds
|
||||
// }
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"sampleCount": float64(body.Parameters["sampleCount"].(int)),
|
||||
}
|
||||
|
||||
// if v, ok := body.Parameters["durationSeconds"]; ok {
|
||||
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
|
||||
// }
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -205,14 +210,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
localID := taskcommon.EncodeLocalTaskID(s.Name)
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
@@ -225,7 +235,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
@@ -245,12 +255,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
data, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
@@ -274,7 +284,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
if err := common.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
@@ -338,7 +348,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
upstreamName, err := decodeLocalTaskID(task.TaskID)
|
||||
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
|
||||
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
|
||||
upstreamTaskID := task.GetUpstreamTaskID()
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
|
||||
if err != nil {
|
||||
upstreamName = ""
|
||||
}
|
||||
@@ -353,8 +366,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
v.SetProgressStr(task.Progress)
|
||||
v.CreatedAt = task.CreatedAt
|
||||
v.CompletedAt = task.UpdatedAt
|
||||
if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
|
||||
v.SetMetadata("url", task.FailReason)
|
||||
if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
|
||||
v.SetMetadata("url", resultURL)
|
||||
}
|
||||
|
||||
return common.Marshal(v)
|
||||
@@ -364,18 +377,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
|
||||
@@ -2,7 +2,6 @@ package vidu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -73,6 +73,7 @@ type creation struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
baseURL string
|
||||
}
|
||||
@@ -115,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
body, err := a.convertToRequestPayload(&req, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -168,7 +169,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
var vResp responsePayload
|
||||
err = json.Unmarshal(responseBody, &vResp)
|
||||
err = common.Unmarshal(responseBody, &vResp)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -180,8 +181,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = vResp.TaskId
|
||||
ov.TaskID = vResp.TaskId
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
@@ -223,47 +224,27 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Model: taskcommon.DefaultString(info.UpstreamModelName, "viduq1"),
|
||||
Images: req.Images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
Duration: taskcommon.DefaultInt(req.Duration, 5),
|
||||
Resolution: taskcommon.DefaultString(req.Size, "1080p"),
|
||||
MovementAmplitude: "auto",
|
||||
Bgm: false,
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func defaultString(value, defaultValue string) string {
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func defaultInt(value, defaultValue int) int {
|
||||
if value == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
|
||||
var taskResp taskResultResponse
|
||||
err := json.Unmarshal(respBody, &taskResp)
|
||||
err := common.Unmarshal(respBody, &taskResp)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
@@ -293,7 +274,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||||
var viduResp taskResultResponse
|
||||
if err := json.Unmarshal(originTask.Data, &viduResp); err != nil {
|
||||
if err := common.Unmarshal(originTask.Data, &viduResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal vidu task data failed")
|
||||
}
|
||||
|
||||
@@ -315,6 +296,5 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, erro
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, _ := common.Marshal(openAIVideo)
|
||||
return jsonData, nil
|
||||
return common.Marshal(openAIVideo)
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
|
||||
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -119,7 +119,7 @@ func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, ad
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
}
|
||||
|
||||
// remove disabled fields for Claude API
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
)
|
||||
|
||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||
@@ -1311,6 +1314,76 @@ func TestApplyParamOverrideWithRelayInfoMoveAndCopyHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenChannelPassThroughEnabled(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"safety_identifier":"user-123",
|
||||
"store":true,
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, true)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsSkipWhenGlobalPassThroughEnabled(t *testing.T) {
|
||||
original := model_setting.GetGlobalSettings().PassThroughRequestEnabled
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = true
|
||||
t.Cleanup(func() {
|
||||
model_setting.GetGlobalSettings().PassThroughRequestEnabled = original
|
||||
})
|
||||
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"safety_identifier":"user-123",
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, input, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsDefaultFiltering(t *testing.T) {
|
||||
input := `{
|
||||
"service_tier":"flex",
|
||||
"inference_geo":"eu",
|
||||
"safety_identifier":"user-123",
|
||||
"store":true,
|
||||
"stream_options":{"include_obfuscation":false}
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"store":true}`, string(out))
|
||||
}
|
||||
|
||||
func TestRemoveDisabledFieldsAllowInferenceGeo(t *testing.T) {
|
||||
input := `{
|
||||
"inference_geo":"eu",
|
||||
"store":true
|
||||
}`
|
||||
settings := dto.ChannelOtherSettings{
|
||||
AllowInferenceGeo: true,
|
||||
}
|
||||
|
||||
out, err := RemoveDisabledFields([]byte(input), settings, false)
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveDisabledFields returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"inference_geo":"eu","store":true}`, string(out))
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -119,8 +119,12 @@ type RelayInfo struct {
|
||||
SendResponseCount int
|
||||
ReceivedResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
// ForcePreConsume 为 true 时禁用 BillingSession 的信任额度旁路,
|
||||
// 强制预扣全额。用于异步任务(视频/音乐生成等),因为请求返回后任务仍在运行,
|
||||
// 必须在提交前锁定全额。
|
||||
ForcePreConsume bool
|
||||
// Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。
|
||||
// 免费模型和按次计费(MJ/Task)时为 nil。
|
||||
// 免费模型时为 nil。
|
||||
Billing BillingSettler
|
||||
// BillingSource indicates whether this request is billed from wallet quota or subscription.
|
||||
// "" or "wallet" => wallet; "subscription" => subscription
|
||||
@@ -153,7 +157,8 @@ type RelayInfo struct {
|
||||
// RequestConversionChain records request format conversions in order, e.g.
|
||||
// ["openai", "openai_responses"] or ["openai", "claude"].
|
||||
RequestConversionChain []types.RelayFormat
|
||||
// 最终请求到上游的格式 TODO: 当前仅设置了Claude
|
||||
// 最终请求到上游的格式。可由 adaptor 显式设置;
|
||||
// 若为空,调用 GetFinalRequestRelayFormat 会回退到 RequestConversionChain 的最后一项或 RelayFormat。
|
||||
FinalRequestRelayFormat types.RelayFormat
|
||||
|
||||
ThinkingContentInfo
|
||||
@@ -552,8 +557,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
|
||||
return nil, errors.New("request is not a OpenAIResponsesCompactionRequest")
|
||||
case types.RelayFormatTask:
|
||||
info = genBaseRelayInfo(c, nil)
|
||||
info.TaskRelayInfo = &TaskRelayInfo{}
|
||||
case types.RelayFormatMjProxy:
|
||||
info = genBaseRelayInfo(c, nil)
|
||||
info.TaskRelayInfo = &TaskRelayInfo{}
|
||||
default:
|
||||
err = errors.New("invalid relay format")
|
||||
}
|
||||
@@ -600,6 +607,19 @@ func (info *RelayInfo) AppendRequestConversion(format types.RelayFormat) {
|
||||
info.RequestConversionChain = append(info.RequestConversionChain, format)
|
||||
}
|
||||
|
||||
func (info *RelayInfo) GetFinalRequestRelayFormat() types.RelayFormat {
|
||||
if info == nil {
|
||||
return ""
|
||||
}
|
||||
if info.FinalRequestRelayFormat != "" {
|
||||
return info.FinalRequestRelayFormat
|
||||
}
|
||||
if n := len(info.RequestConversionChain); n > 0 {
|
||||
return info.RequestConversionChain[n-1]
|
||||
}
|
||||
return info.RelayFormat
|
||||
}
|
||||
|
||||
func GenRelayInfoResponsesCompaction(c *gin.Context, request *dto.OpenAIResponsesCompactionRequest) *RelayInfo {
|
||||
info := genBaseRelayInfo(c, request)
|
||||
if info.RelayMode == relayconstant.RelayModeUnknown {
|
||||
@@ -635,8 +655,16 @@ func (info *RelayInfo) HasSendResponse() bool {
|
||||
type TaskRelayInfo struct {
|
||||
Action string
|
||||
OriginTaskID string
|
||||
// PublicTaskID 是提交时预生成的 task_xxxx 格式公开 ID,
|
||||
// 供 DoResponse 在返回给客户端时使用(避免暴露上游真实 ID)。
|
||||
PublicTaskID string
|
||||
|
||||
ConsumeQuota bool
|
||||
|
||||
// LockedChannel holds the full channel object when the request is bound to
|
||||
// a specific channel (e.g., remix on origin task's channel). Stored as any
|
||||
// to avoid an import cycle with model; callers type-assert to *model.Channel.
|
||||
LockedChannel any
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
@@ -694,11 +722,11 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
|
||||
func (t *TaskSubmitReq) UnmarshalMetadata(v any) error {
|
||||
metadata := t.Metadata
|
||||
if metadata != nil {
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
metadataBytes, err := common.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal metadata failed: %w", err)
|
||||
}
|
||||
err = json.Unmarshal(metadataBytes, v)
|
||||
err = common.Unmarshal(metadataBytes, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal metadata to target failed: %w", err)
|
||||
}
|
||||
@@ -727,9 +755,15 @@ func FailTaskInfo(reason string) *TaskInfo {
|
||||
|
||||
// RemoveDisabledFields 从请求 JSON 数据中移除渠道设置中禁用的字段
|
||||
// service_tier: 服务层级字段,可能导致额外计费(OpenAI、Claude、Responses API 支持)
|
||||
// inference_geo: Claude 数据驻留推理区域字段(仅 Claude 支持,默认过滤)
|
||||
// store: 数据存储授权字段,涉及用户隐私(仅 OpenAI、Responses API 支持,默认允许透传,禁用后可能导致 Codex 无法使用)
|
||||
// safety_identifier: 安全标识符,用于向 OpenAI 报告违规用户(仅 OpenAI 支持,涉及用户隐私)
|
||||
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings) ([]byte, error) {
|
||||
// stream_options.include_obfuscation: 响应流混淆控制字段(仅 OpenAI Responses API 支持)
|
||||
func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOtherSettings, channelPassThroughEnabled bool) ([]byte, error) {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || channelPassThroughEnabled {
|
||||
return jsonData, nil
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := common.Unmarshal(jsonData, &data); err != nil {
|
||||
common.SysError("RemoveDisabledFields Unmarshal error :" + err.Error())
|
||||
@@ -743,6 +777,13 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
}
|
||||
|
||||
// 默认移除 inference_geo,除非明确允许(避免在未授权情况下透传数据驻留区域)
|
||||
if !channelOtherSettings.AllowInferenceGeo {
|
||||
if _, exists := data["inference_geo"]; exists {
|
||||
delete(data, "inference_geo")
|
||||
}
|
||||
}
|
||||
|
||||
// 默认允许 store 透传,除非明确禁用(禁用可能影响 Codex 使用)
|
||||
if channelOtherSettings.DisableStore {
|
||||
if _, exists := data["store"]; exists {
|
||||
@@ -757,6 +798,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
|
||||
}
|
||||
}
|
||||
|
||||
// 默认移除 stream_options.include_obfuscation,除非明确允许(避免关闭响应流混淆保护)
|
||||
if !channelOtherSettings.AllowIncludeObfuscation {
|
||||
if streamOptionsAny, exists := data["stream_options"]; exists {
|
||||
if streamOptions, ok := streamOptionsAny.(map[string]interface{}); ok {
|
||||
if _, includeExists := streamOptions["include_obfuscation"]; includeExists {
|
||||
delete(streamOptions, "include_obfuscation")
|
||||
}
|
||||
if len(streamOptions) == 0 {
|
||||
delete(data, "stream_options")
|
||||
} else {
|
||||
data["stream_options"] = streamOptions
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonDataAfter, err := common.Marshal(data)
|
||||
if err != nil {
|
||||
common.SysError("RemoveDisabledFields Marshal error :" + err.Error())
|
||||
|
||||
40
relay/common/relay_info_test.go
Normal file
40
relay/common/relay_info_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatPrefersExplicitFinal(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
|
||||
FinalRequestRelayFormat: types.RelayFormatOpenAIResponses,
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatOpenAIResponses), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToConversionChain(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatOpenAI,
|
||||
RequestConversionChain: []types.RelayFormat{types.RelayFormatOpenAI, types.RelayFormatClaude},
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatClaude), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatFallsBackToRelayFormat(t *testing.T) {
|
||||
info := &RelayInfo{
|
||||
RelayFormat: types.RelayFormatGemini,
|
||||
}
|
||||
|
||||
require.Equal(t, types.RelayFormat(types.RelayFormatGemini), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
func TestRelayInfoGetFinalRequestRelayFormatNilReceiver(t *testing.T) {
|
||||
var info *RelayInfo
|
||||
require.Equal(t, types.RelayFormat(""), info.GetFinalRequestRelayFormat())
|
||||
}
|
||||
@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
|
||||
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
||||
}
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"seconds": float64(seconds),
|
||||
"size": 1,
|
||||
}
|
||||
if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
|
||||
info.PriceData.OtherRatios["size"] = 1.666667
|
||||
}
|
||||
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
|
||||
}
|
||||
|
||||
info.Action = action
|
||||
storeTaskRequest(c, info, action, req)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
}
|
||||
|
||||
// remove disabled fields for OpenAI API
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
}
|
||||
|
||||
if originUsage != nil {
|
||||
service.ObserveChannelAffinityUsageCacheFromContext(ctx, usage)
|
||||
service.ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
adminRejectReason := common.GetContextKeyString(ctx, constant.ContextKeyAdminRejectReason)
|
||||
@@ -336,7 +336,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
|
||||
var audioInputQuota decimal.Decimal
|
||||
var audioInputPrice float64
|
||||
isClaudeUsageSemantic := relayInfo.FinalRequestRelayFormat == types.RelayFormatClaude
|
||||
isClaudeUsageSemantic := relayInfo.GetFinalRequestRelayFormat() == types.RelayFormatClaude
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
|
||||
@@ -140,7 +140,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
}
|
||||
|
||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PriceData {
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||
@@ -154,7 +154,18 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.
|
||||
}
|
||||
}
|
||||
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
priceData := types.PerCallPriceData{
|
||||
|
||||
// 免费模型检测(与 ModelPriceHelper 对齐)
|
||||
freeModel := false
|
||||
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
|
||||
if groupRatioInfo.GroupRatio == 0 || modelPrice == 0 {
|
||||
quota = 0
|
||||
freeModel = true
|
||||
}
|
||||
}
|
||||
|
||||
priceData := types.PriceData{
|
||||
FreeModel: freeModel,
|
||||
ModelPrice: modelPrice,
|
||||
Quota: quota,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
|
||||
@@ -176,10 +176,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
})
|
||||
}
|
||||
|
||||
dataChan := make(chan string, 10)
|
||||
|
||||
wg.Add(1)
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}()
|
||||
for data := range dataChan {
|
||||
writeMutex.Lock()
|
||||
success := dataHandler(data)
|
||||
writeMutex.Unlock()
|
||||
if !success {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Scanner goroutine with improved error handling
|
||||
wg.Add(1)
|
||||
common.RelayCtxGo(ctx, func() {
|
||||
defer func() {
|
||||
close(dataChan)
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
@@ -215,27 +237,16 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
data = strings.TrimLeft(data, " ")
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
data = strings.TrimSpace(data)
|
||||
if data == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
info.SetFirstResponseTime()
|
||||
info.ReceivedResponseCount++
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan bool, 1)
|
||||
gopool.Go(func() {
|
||||
writeMutex.Lock()
|
||||
defer writeMutex.Unlock()
|
||||
done <- dataHandler(data)
|
||||
})
|
||||
|
||||
select {
|
||||
case success := <-done:
|
||||
if !success {
|
||||
return
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
logger.LogError(c, "data handler timeout")
|
||||
return
|
||||
case dataChan <- data:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-stopChan:
|
||||
|
||||
521
relay/helper/stream_scanner_test.go
Normal file
521
relay/helper/stream_scanner_test.go
Normal file
@@ -0,0 +1,521 @@
|
||||
package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func setupStreamTest(t *testing.T, body io.Reader) (*gin.Context, *http.Response, *relaycommon.RelayInfo) {
|
||||
t.Helper()
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
Body: io.NopCloser(body),
|
||||
}
|
||||
|
||||
info := &relaycommon.RelayInfo{
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
}
|
||||
|
||||
return c, resp, info
|
||||
}
|
||||
|
||||
func buildSSEBody(n int) string {
|
||||
var b strings.Builder
|
||||
for i := 0; i < n; i++ {
|
||||
fmt.Fprintf(&b, "data: {\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}\n", i, i)
|
||||
}
|
||||
b.WriteString("data: [DONE]\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// slowReader wraps a reader and injects a delay before each Read call,
|
||||
// simulating a slow upstream that trickles data.
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (s *slowReader) Read(p []byte) (int, error) {
|
||||
time.Sleep(s.delay)
|
||||
return s.r.Read(p)
|
||||
}
|
||||
|
||||
// ---------- Basic correctness ----------
|
||||
|
||||
func TestStreamScannerHandler_NilInputs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
StreamScannerHandler(c, nil, info, func(data string) bool { return true })
|
||||
StreamScannerHandler(c, &http.Response{Body: io.NopCloser(strings.NewReader(""))}, info, nil)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_EmptyBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(""))
|
||||
|
||||
var called atomic.Bool
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
called.Store(true)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.False(t, called.Load(), "handler should not be called for empty body")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_1000Chunks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 1000
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
assert.Equal(t, numChunks, info.ReceivedResponseCount)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_10000Chunks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 10000
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
elapsed := time.Since(start)
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
assert.Equal(t, numChunks, info.ReceivedResponseCount)
|
||||
t.Logf("10000 chunks processed in %v", elapsed)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_OrderPreserved(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 500
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var mu sync.Mutex
|
||||
received := make([]string, 0, numChunks)
|
||||
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
mu.Lock()
|
||||
received = append(received, data)
|
||||
mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
require.Equal(t, numChunks, len(received))
|
||||
for i := 0; i < numChunks; i++ {
|
||||
expected := fmt.Sprintf("{\"id\":%d,\"choices\":[{\"delta\":{\"content\":\"token_%d\"}}]}", i, i)
|
||||
assert.Equal(t, expected, received[i], "chunk %d out of order", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_DoneStopsScanner(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := buildSSEBody(50) + "data: should_not_appear\n"
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(50), count.Load(), "data after [DONE] must not be processed")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_HandlerFailureStops(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 200
|
||||
body := buildSSEBody(numChunks)
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
const failAt = 50
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
n := count.Add(1)
|
||||
return n < failAt
|
||||
})
|
||||
|
||||
// The worker stops at failAt; the scanner may have read ahead,
|
||||
// but the handler should not be called beyond failAt.
|
||||
assert.Equal(t, int64(failAt), count.Load())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SkipsNonDataLines(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(": comment line\n")
|
||||
b.WriteString("event: message\n")
|
||||
b.WriteString("id: 12345\n")
|
||||
b.WriteString("retry: 5000\n")
|
||||
for i := 0; i < 100; i++ {
|
||||
fmt.Fprintf(&b, "data: payload_%d\n", i)
|
||||
b.WriteString(": interleaved comment\n")
|
||||
}
|
||||
b.WriteString("data: [DONE]\n")
|
||||
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(b.String()))
|
||||
|
||||
var count atomic.Int64
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, int64(100), count.Load())
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_DataWithExtraSpaces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := "data: {\"trimmed\":true} \ndata: [DONE]\n"
|
||||
c, resp, info := setupStreamTest(t, strings.NewReader(body))
|
||||
|
||||
var got string
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
got = data
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, "{\"trimmed\":true}", got)
|
||||
}
|
||||
|
||||
// ---------- Decoupling: scanner not blocked by slow handler ----------
|
||||
|
||||
func TestStreamScannerHandler_ScannerDecoupledFromSlowHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Strategy: use a slow upstream (io.Pipe, 10ms per chunk) AND a slow handler (20ms per chunk).
|
||||
// If the scanner were synchronously coupled to the handler, total time would be
|
||||
// ~numChunks * (10ms + 20ms) = 30ms * 50 = 1500ms.
|
||||
// With decoupling, total time should be closer to
|
||||
// ~numChunks * max(10ms, 20ms) = 20ms * 50 = 1000ms
|
||||
// because the scanner reads ahead into the buffer while the handler processes.
|
||||
const numChunks = 50
|
||||
const upstreamDelay = 10 * time.Millisecond
|
||||
const handlerDelay = 20 * time.Millisecond
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < numChunks; i++ {
|
||||
fmt.Fprintf(pw, "data: {\"id\":%d}\n", i)
|
||||
time.Sleep(upstreamDelay)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() { constant.StreamingTimeout = oldTimeout })
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
time.Sleep(handlerDelay)
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("StreamScannerHandler did not complete in time")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
|
||||
coupledTime := time.Duration(numChunks) * (upstreamDelay + handlerDelay)
|
||||
t.Logf("elapsed=%v, coupled_estimate=%v", elapsed, coupledTime)
|
||||
|
||||
// If decoupled, elapsed should be well under the coupled estimate.
|
||||
assert.Less(t, elapsed, coupledTime*85/100,
|
||||
"decoupled elapsed time (%v) should be significantly less than coupled estimate (%v)", elapsed, coupledTime)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_SlowUpstreamFastHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const numChunks = 50
|
||||
body := buildSSEBody(numChunks)
|
||||
reader := &slowReader{r: strings.NewReader(body), delay: 2 * time.Millisecond}
|
||||
c, resp, info := setupStreamTest(t, reader)
|
||||
|
||||
var count atomic.Int64
|
||||
start := time.Now()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out with slow upstream")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
assert.Equal(t, int64(numChunks), count.Load())
|
||||
t.Logf("slow upstream (%d chunks, 2ms/read): %v", numChunks, elapsed)
|
||||
}
|
||||
|
||||
// ---------- Ping tests ----------
|
||||
|
||||
func TestStreamScannerHandler_PingSentDuringSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setting := operation_setting.GetGeneralSetting()
|
||||
oldEnabled := setting.PingIntervalEnabled
|
||||
oldSeconds := setting.PingIntervalSeconds
|
||||
setting.PingIntervalEnabled = true
|
||||
setting.PingIntervalSeconds = 1
|
||||
t.Cleanup(func() {
|
||||
setting.PingIntervalEnabled = oldEnabled
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Create a reader that delivers data slowly: one chunk every 500ms over 3.5 seconds.
|
||||
// The ping interval is 1s, so we should see at least 2 pings.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < 7; i++ {
|
||||
fmt.Fprintf(pw, "data: chunk_%d\n", i)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for stream to finish")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(7), count.Load())
|
||||
|
||||
body := recorder.Body.String()
|
||||
pingCount := strings.Count(body, ": PING")
|
||||
t.Logf("received %d pings in response body", pingCount)
|
||||
assert.GreaterOrEqual(t, pingCount, 2,
|
||||
"expected at least 2 pings during 3.5s stream with 1s interval; got %d", pingCount)
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingDisabledByRelayInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setting := operation_setting.GetGeneralSetting()
|
||||
oldEnabled := setting.PingIntervalEnabled
|
||||
oldSeconds := setting.PingIntervalSeconds
|
||||
setting.PingIntervalEnabled = true
|
||||
setting.PingIntervalSeconds = 1
|
||||
t.Cleanup(func() {
|
||||
setting.PingIntervalEnabled = oldEnabled
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < 5; i++ {
|
||||
fmt.Fprintf(pw, "data: chunk_%d\n", i)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{
|
||||
DisablePing: true,
|
||||
ChannelMeta: &relaycommon.ChannelMeta{},
|
||||
}
|
||||
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(5), count.Load())
|
||||
|
||||
body := recorder.Body.String()
|
||||
pingCount := strings.Count(body, ": PING")
|
||||
assert.Equal(t, 0, pingCount, "pings should be disabled when DisablePing=true")
|
||||
}
|
||||
|
||||
func TestStreamScannerHandler_PingInterleavesWithSlowUpstream(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setting := operation_setting.GetGeneralSetting()
|
||||
oldEnabled := setting.PingIntervalEnabled
|
||||
oldSeconds := setting.PingIntervalSeconds
|
||||
setting.PingIntervalEnabled = true
|
||||
setting.PingIntervalSeconds = 1
|
||||
t.Cleanup(func() {
|
||||
setting.PingIntervalEnabled = oldEnabled
|
||||
setting.PingIntervalSeconds = oldSeconds
|
||||
})
|
||||
|
||||
// Slow upstream + slow handler. Total stream takes ~5 seconds.
|
||||
// The ping goroutine stays alive as long as the scanner is reading,
|
||||
// so pings should fire between data writes.
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
for i := 0; i < 10; i++ {
|
||||
fmt.Fprintf(pw, "data: chunk_%d\n", i)
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
fmt.Fprint(pw, "data: [DONE]\n")
|
||||
}()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
oldTimeout := constant.StreamingTimeout
|
||||
constant.StreamingTimeout = 30
|
||||
t.Cleanup(func() {
|
||||
constant.StreamingTimeout = oldTimeout
|
||||
})
|
||||
|
||||
resp := &http.Response{Body: pr}
|
||||
info := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{}}
|
||||
|
||||
var count atomic.Int64
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
count.Add(1)
|
||||
return true
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(10), count.Load())
|
||||
|
||||
body := recorder.Body.String()
|
||||
pingCount := strings.Count(body, ": PING")
|
||||
t.Logf("received %d pings interleaved with 10 chunks over 5s", pingCount)
|
||||
assert.GreaterOrEqual(t, pingCount, 3,
|
||||
"expected at least 3 pings during 5s stream with 1s ping interval; got %d", pingCount)
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyR
|
||||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelName := service.CovertMjpActionToModelName(constant.MjActionSwapFace)
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, info)
|
||||
|
||||
@@ -485,7 +485,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dt
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelName := service.CovertMjpActionToModelName(midjRequest.Action)
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,29 +14,33 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
/*
|
||||
Task 任务通过平台、Action 区分任务
|
||||
*/
|
||||
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
info.InitChannelMeta(c)
|
||||
// ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields
|
||||
if info.TaskRelayInfo == nil {
|
||||
info.TaskRelayInfo = &relaycommon.TaskRelayInfo{}
|
||||
}
|
||||
type TaskSubmitResult struct {
|
||||
UpstreamTaskID string
|
||||
TaskData []byte
|
||||
Platform constant.TaskPlatform
|
||||
Quota int
|
||||
//PerCallPrice types.PriceData
|
||||
}
|
||||
|
||||
// ResolveOriginTask 处理基于已有任务的提交(remix / continuation):
|
||||
// 查找原始任务、从中提取模型名称、将渠道锁定到原始任务的渠道
|
||||
// (通过 info.LockedChannel,重试时复用同一渠道并轮换 key),
|
||||
// 以及提取 OtherRatios(时长、分辨率)。
|
||||
// 该函数在控制器的重试循环之前调用一次,其结果通过 info 字段和上下文持久化。
|
||||
func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
// 检测 remix action
|
||||
path := c.Request.URL.Path
|
||||
if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") {
|
||||
info.Action = constant.TaskActionRemix
|
||||
}
|
||||
|
||||
// 提取 remix 任务的 video_id
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
videoID := c.Param("video_id")
|
||||
if strings.TrimSpace(videoID) == "" {
|
||||
@@ -46,64 +49,71 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
info.OriginTaskID = videoID
|
||||
}
|
||||
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
if info.OriginTaskID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取原始任务信息
|
||||
if info.OriginTaskID != "" {
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
platform = originTask.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
key, _, newAPIError := channel.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
taskErr = service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
return
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
// 查找原始任务
|
||||
originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if !exist {
|
||||
return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
info.ChannelBaseUrl = channel.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = channel.Type
|
||||
info.ApiKey = key
|
||||
platform = originTask.Platform
|
||||
}
|
||||
|
||||
// 使用原始任务的参数
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
// 从原始任务推导模型名称
|
||||
if info.OriginModelName == "" {
|
||||
if originTask.Properties.OriginModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.OriginModelName
|
||||
} else if originTask.Properties.UpstreamModelName != "" {
|
||||
info.OriginModelName = originTask.Properties.UpstreamModelName
|
||||
} else {
|
||||
var taskData map[string]interface{}
|
||||
_ = json.Unmarshal(originTask.Data, &taskData)
|
||||
_ = common.Unmarshal(originTask.Data, &taskData)
|
||||
if m, ok := taskData["model"].(string); ok && m != "" {
|
||||
info.OriginModelName = m
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 锁定到原始任务的渠道(重试时复用同一渠道,轮换 key)
|
||||
ch, err := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
||||
}
|
||||
if ch.Status != common.ChannelStatusEnabled {
|
||||
return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest)
|
||||
}
|
||||
info.LockedChannel = ch
|
||||
|
||||
if originTask.ChannelId != info.ChannelId {
|
||||
key, _, newAPIError := ch.GetNextEnabledKey()
|
||||
if newAPIError != nil {
|
||||
return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode)
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelKey, key)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId)
|
||||
|
||||
info.ChannelBaseUrl = ch.GetBaseURL()
|
||||
info.ChannelId = originTask.ChannelId
|
||||
info.ChannelType = ch.Type
|
||||
info.ApiKey = key
|
||||
}
|
||||
|
||||
// 提取 remix 参数(时长、分辨率 → OtherRatios)
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
if originTask.PrivateData.BillingContext != nil {
|
||||
// 新的 remix 逻辑:直接从原始任务的 BillingContext 中提取 OtherRatios(如果存在)
|
||||
for s, f := range originTask.PrivateData.BillingContext.OtherRatios {
|
||||
info.PriceData.AddOtherRatio(s, f)
|
||||
}
|
||||
} else {
|
||||
// 旧的 remix 逻辑:直接从 task data 解析 seconds 和 size(如果存在)
|
||||
var taskData map[string]interface{}
|
||||
_ = common.Unmarshal(originTask.Data, &taskData)
|
||||
secondsStr, _ := taskData["seconds"].(string)
|
||||
seconds, _ := strconv.Atoi(secondsStr)
|
||||
if seconds <= 0 {
|
||||
@@ -120,167 +130,146 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
|
||||
// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
|
||||
// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
|
||||
// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
|
||||
// 控制器负责 defer Refund 和成功后 Settle。
|
||||
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
// 1. 确定 platform → 创建适配器 → 验证请求
|
||||
platform := constant.TaskPlatform(c.GetString("platform"))
|
||||
if platform == "" {
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(info)
|
||||
// get & validate taskRequest 获取并验证文本请求
|
||||
taskErr = adaptor.ValidateRequestAndSetAction(c, info)
|
||||
if taskErr != nil {
|
||||
return
|
||||
if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil {
|
||||
return nil, taskErr
|
||||
}
|
||||
|
||||
// 2. 确定模型名称
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = service.CoverTaskActionToModelName(platform, info.Action)
|
||||
}
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = float64(common.PreConsumedQuota) / common.QuotaPerUnit
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
|
||||
// 2.5 应用渠道的模型映射(与同步任务对齐)
|
||||
info.OriginModelName = modelName
|
||||
info.UpstreamModelName = modelName
|
||||
if err := helper.ModelMappedHelper(c, info, nil); err != nil {
|
||||
return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// 3. 预生成公开 task ID(仅首次)
|
||||
if info.PublicTaskID == "" {
|
||||
info.PublicTaskID = model.GenerateTaskID()
|
||||
}
|
||||
|
||||
// 4. 价格计算:基础模型价格
|
||||
info.OriginModelName = modelName
|
||||
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
|
||||
|
||||
// 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
|
||||
// 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
|
||||
// ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
|
||||
if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
|
||||
for k, v := range estimatedRatios {
|
||||
info.PriceData.AddOtherRatio(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 auto 分组:从 context 获取实际选中的分组
|
||||
// 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
|
||||
if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
|
||||
if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
|
||||
info.UsingGroup = groupStr
|
||||
}
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
|
||||
var ratio float64
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup)
|
||||
if hasUserGroupRatio {
|
||||
ratio = modelPrice * userGroupRatio
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
// 6. 将 OtherRatios 应用到基础额度
|
||||
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
}
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if ra != 1.0 {
|
||||
info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra)
|
||||
}
|
||||
}
|
||||
}
|
||||
println(fmt.Sprintf("model: %s, model_price: %.4f, group: %s, group_ratio: %.4f, final_ratio: %.4f", modelName, modelPrice, info.UsingGroup, groupRatio, ratio))
|
||||
userQuota, err := model.GetUserQuota(info.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
if userQuota-quota < 0 {
|
||||
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
|
||||
return
|
||||
|
||||
// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
||||
if info.Billing == nil && !info.PriceData.FreeModel {
|
||||
info.ForcePreConsume = true
|
||||
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
|
||||
return nil, service.TaskErrorFromAPIError(apiErr)
|
||||
}
|
||||
}
|
||||
|
||||
// build body
|
||||
// 8. 构建请求体
|
||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
// do request
|
||||
|
||||
// 9. 发送请求
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
// handle response
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
return
|
||||
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// release quota
|
||||
if info.ConsumeQuota && taskErr == nil {
|
||||
// 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
|
||||
otherRatios := info.PriceData.OtherRatios
|
||||
if otherRatios == nil {
|
||||
otherRatios = map[string]float64{}
|
||||
}
|
||||
ratiosJSON, _ := common.Marshal(otherRatios)
|
||||
c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
|
||||
|
||||
err := service.PostConsumeQuota(info, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
//gRatio := groupRatio
|
||||
//if hasUserGroupRatio {
|
||||
// gRatio = userGroupRatio
|
||||
//}
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
if c != nil && c.Request != nil && c.Request.URL != nil {
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
}
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: info.TokenId,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(info.ChannelId, quota)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
||||
// 11. 解析响应
|
||||
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
||||
if taskErr != nil {
|
||||
return
|
||||
return nil, taskErr
|
||||
}
|
||||
info.ConsumeQuota = true
|
||||
// insert task
|
||||
task := model.InitTask(platform, info)
|
||||
task.TaskID = taskID
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
task.Action = info.Action
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
|
||||
finalQuota := info.PriceData.Quota
|
||||
if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
|
||||
// 基于调整后的 ratios 重新计算 quota
|
||||
finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
|
||||
info.PriceData.OtherRatios = adjustedRatios
|
||||
info.PriceData.Quota = finalQuota
|
||||
}
|
||||
return nil
|
||||
|
||||
return &TaskSubmitResult{
|
||||
UpstreamTaskID: upstreamTaskID,
|
||||
TaskData: taskData,
|
||||
Platform: platform,
|
||||
Quota: finalQuota,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
|
||||
// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
|
||||
func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
|
||||
// 从 PriceData 获取不含 OtherRatios 的基础价格
|
||||
baseQuota := info.PriceData.Quota
|
||||
// 先除掉原有的 OtherRatios 恢复基础额度
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if ra != 1.0 && ra > 0 {
|
||||
baseQuota = int(float64(baseQuota) / ra)
|
||||
}
|
||||
}
|
||||
// 应用新的 ratios
|
||||
result := float64(baseQuota)
|
||||
for _, ra := range ratios {
|
||||
if ra != 1.0 {
|
||||
result *= ra
|
||||
}
|
||||
}
|
||||
return int(result)
|
||||
}
|
||||
|
||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||
@@ -336,7 +325,7 @@ func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.Ta
|
||||
} else {
|
||||
tasks = make([]any, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
|
||||
respBody, err = common.Marshal(dto.TaskResponse[[]any]{
|
||||
Code: "success",
|
||||
Data: tasks,
|
||||
})
|
||||
@@ -357,7 +346,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
respBody, err = common.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
@@ -381,97 +370,16 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
func() {
|
||||
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
|
||||
return
|
||||
}
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
}
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
}, proxy)
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err2 := io.ReadAll(resp.Body)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
ti, err2 := adaptor.ParseTaskResult(body)
|
||||
if err2 == nil && ti != nil {
|
||||
if ti.Status != "" {
|
||||
originTask.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
originTask.Progress = ti.Progress
|
||||
}
|
||||
if ti.Url != "" {
|
||||
if strings.HasPrefix(ti.Url, "data:") {
|
||||
} else {
|
||||
originTask.FailReason = ti.Url
|
||||
}
|
||||
}
|
||||
_ = originTask.Update()
|
||||
var raw map[string]any
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
format := "mp4"
|
||||
if respObj, ok := raw["response"].(map[string]any); ok {
|
||||
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
||||
if v0, ok := vids[0].(map[string]any); ok {
|
||||
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
||||
if strings.Contains(mt, "mp4") {
|
||||
format = "mp4"
|
||||
} else {
|
||||
format = mt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status := "processing"
|
||||
switch originTask.Status {
|
||||
case model.TaskStatusSuccess:
|
||||
status = "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
status = "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
status = "queued"
|
||||
}
|
||||
if !strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": status,
|
||||
"task_id": originTask.TaskID,
|
||||
"url": originTask.FailReason,
|
||||
}
|
||||
respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/")
|
||||
|
||||
if len(respBody) != 0 {
|
||||
// Gemini/Vertex 支持实时查询:用户 fetch 时直接从上游拉取最新状态
|
||||
if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 {
|
||||
respBody = realtimeResp
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
|
||||
// OpenAI Video API 格式: 走各 adaptor 的 ConvertToOpenAIVideo
|
||||
if isOpenAIVideoAPI {
|
||||
adaptor := GetTaskAdaptor(originTask.Platform)
|
||||
if adaptor == nil {
|
||||
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
|
||||
@@ -486,10 +394,12 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
respBody = openAIVideoData
|
||||
return
|
||||
}
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
|
||||
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
|
||||
// 通用 TaskDto 格式
|
||||
respBody, err = common.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
@@ -499,16 +409,150 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
// tryRealtimeFetch 尝试从上游实时拉取 Gemini/Vertex 任务状态。
|
||||
// 仅当渠道类型为 Gemini 或 Vertex 时触发;其他渠道或出错时返回 nil。
|
||||
// 当非 OpenAI Video API 时,还会构建自定义格式的响应体。
|
||||
func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte {
|
||||
channelModel, err := model.GetChannelById(task.ChannelId, true)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini {
|
||||
return nil
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
proxy := channelModel.GetSetting().Proxy
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": task.GetUpstreamTaskID(),
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil || resp == nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ti, err := adaptor.ParseTaskResult(body)
|
||||
if err != nil || ti == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
snap := task.Snapshot()
|
||||
|
||||
// 将上游最新状态更新到 task
|
||||
if ti.Status != "" {
|
||||
task.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
task.Progress = ti.Progress
|
||||
}
|
||||
if strings.HasPrefix(ti.Url, "data:") {
|
||||
// data: URI — kept in Data, not ResultURL
|
||||
} else if ti.Url != "" {
|
||||
task.PrivateData.ResultURL = ti.Url
|
||||
} else if task.Status == model.TaskStatusSuccess {
|
||||
// No URL from adaptor — construct proxy URL using public task ID
|
||||
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
|
||||
}
|
||||
|
||||
if !snap.Equal(task.Snapshot()) {
|
||||
_, _ = task.UpdateWithStatus(snap.Status)
|
||||
}
|
||||
|
||||
// OpenAI Video API 由调用者的 ConvertToOpenAIVideo 分支处理
|
||||
if isOpenAIVideoAPI {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 非 OpenAI Video API: 构建自定义格式响应
|
||||
format := detectVideoFormat(body)
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": mapTaskStatusToSimple(task.Status),
|
||||
"task_id": task.TaskID,
|
||||
"url": task.GetResultURL(),
|
||||
}
|
||||
respBody, _ := common.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
return respBody
|
||||
}
|
||||
|
||||
// detectVideoFormat 从 Gemini/Vertex 原始响应中探测视频格式
|
||||
func detectVideoFormat(rawBody []byte) string {
|
||||
var raw map[string]any
|
||||
if err := common.Unmarshal(rawBody, &raw); err != nil {
|
||||
return "mp4"
|
||||
}
|
||||
respObj, ok := raw["response"].(map[string]any)
|
||||
if !ok {
|
||||
return "mp4"
|
||||
}
|
||||
vids, ok := respObj["videos"].([]any)
|
||||
if !ok || len(vids) == 0 {
|
||||
return "mp4"
|
||||
}
|
||||
v0, ok := vids[0].(map[string]any)
|
||||
if !ok {
|
||||
return "mp4"
|
||||
}
|
||||
mt, ok := v0["mimeType"].(string)
|
||||
if !ok || mt == "" || strings.Contains(mt, "mp4") {
|
||||
return "mp4"
|
||||
}
|
||||
return mt
|
||||
}
|
||||
|
||||
// mapTaskStatusToSimple 将内部 TaskStatus 映射为简化状态字符串
|
||||
func mapTaskStatusToSimple(status model.TaskStatus) string {
|
||||
switch status {
|
||||
case model.TaskStatusSuccess:
|
||||
return "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
return "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
return "queued"
|
||||
default:
|
||||
return "processing"
|
||||
}
|
||||
}
|
||||
|
||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||
return &dto.TaskDto{
|
||||
ID: task.ID,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
TaskID: task.TaskID,
|
||||
Platform: string(task.Platform),
|
||||
UserId: task.UserId,
|
||||
Group: task.Group,
|
||||
ChannelId: task.ChannelId,
|
||||
Quota: task.Quota,
|
||||
Action: task.Action,
|
||||
Status: string(task.Status),
|
||||
FailReason: task.FailReason,
|
||||
ResultURL: task.GetResultURL(),
|
||||
SubmitTime: task.SubmitTime,
|
||||
StartTime: task.StartTime,
|
||||
FinishTime: task.FinishTime,
|
||||
Progress: task.Progress,
|
||||
Properties: task.Properties,
|
||||
Username: task.Username,
|
||||
Data: task.Data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
|
||||
// remove disabled fields for OpenAI Responses API
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter := router.Group("/api")
|
||||
apiRouter.Use(middleware.RouteTag("api"))
|
||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
apiRouter.Use(middleware.BodyStorageCleanup()) // 清理请求体存储
|
||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||
@@ -114,6 +115,9 @@ func SetApiRouter(router *gin.Engine) {
|
||||
adminRoute.GET("/topup", controller.GetAllTopUps)
|
||||
adminRoute.POST("/topup/complete", controller.AdminCompleteTopUp)
|
||||
adminRoute.GET("/search", controller.SearchUsers)
|
||||
adminRoute.GET("/:id/oauth/bindings", controller.GetUserOAuthBindingsByAdmin)
|
||||
adminRoute.DELETE("/:id/oauth/bindings/:provider_id", controller.UnbindCustomOAuthByAdmin)
|
||||
adminRoute.DELETE("/:id/bindings/:binding_type", controller.AdminClearUserBinding)
|
||||
adminRoute.GET("/:id", controller.GetUser)
|
||||
adminRoute.POST("/", controller.CreateUser)
|
||||
adminRoute.POST("/manage", controller.ManageUser)
|
||||
@@ -170,10 +174,11 @@ func SetApiRouter(router *gin.Engine) {
|
||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||
}
|
||||
|
||||
// Custom OAuth provider management (admin only)
|
||||
// Custom OAuth provider management (root only)
|
||||
customOAuthRoute := apiRouter.Group("/custom-oauth-provider")
|
||||
customOAuthRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
customOAuthRoute.POST("/discovery", controller.FetchCustomOAuthDiscovery)
|
||||
customOAuthRoute.GET("/", controller.GetCustomOAuthProviders)
|
||||
customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider)
|
||||
customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
func SetDashboardRouter(router *gin.Engine) {
|
||||
apiRouter := router.Group("/")
|
||||
apiRouter.Use(middleware.RouteTag("old_api"))
|
||||
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
apiRouter.Use(middleware.GlobalAPIRateLimit())
|
||||
apiRouter.Use(middleware.CORS())
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -27,6 +28,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
} else {
|
||||
frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
c.Set(middleware.RouteTagKey, "web")
|
||||
c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
router.Use(middleware.StatsMiddleware())
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.RouteTag("relay"))
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
modelsRouter.GET("", func(c *gin.Context) {
|
||||
@@ -41,6 +42,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
geminiRouter := router.Group("/v1beta/models")
|
||||
geminiRouter.Use(middleware.RouteTag("relay"))
|
||||
geminiRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
geminiRouter.GET("", func(c *gin.Context) {
|
||||
@@ -49,6 +51,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
geminiCompatibleRouter := router.Group("/v1beta/openai/models")
|
||||
geminiCompatibleRouter.Use(middleware.RouteTag("relay"))
|
||||
geminiCompatibleRouter.Use(middleware.TokenAuth())
|
||||
{
|
||||
geminiCompatibleRouter.GET("", func(c *gin.Context) {
|
||||
@@ -57,12 +60,14 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
playgroundRouter := router.Group("/pg")
|
||||
playgroundRouter.Use(middleware.RouteTag("relay"))
|
||||
playgroundRouter.Use(middleware.SystemPerformanceCheck())
|
||||
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
|
||||
{
|
||||
playgroundRouter.POST("/chat/completions", controller.Playground)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.RouteTag("relay"))
|
||||
relayV1Router.Use(middleware.SystemPerformanceCheck())
|
||||
relayV1Router.Use(middleware.TokenAuth())
|
||||
relayV1Router.Use(middleware.ModelRequestRateLimit())
|
||||
@@ -161,24 +166,28 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
relayMjRouter := router.Group("/mj")
|
||||
relayMjRouter.Use(middleware.RouteTag("relay"))
|
||||
relayMjRouter.Use(middleware.SystemPerformanceCheck())
|
||||
registerMjRouterGroup(relayMjRouter)
|
||||
|
||||
relayMjModeRouter := router.Group("/:mode/mj")
|
||||
relayMjModeRouter.Use(middleware.RouteTag("relay"))
|
||||
relayMjModeRouter.Use(middleware.SystemPerformanceCheck())
|
||||
registerMjRouterGroup(relayMjModeRouter)
|
||||
//relayMjRouter.Use()
|
||||
|
||||
relaySunoRouter := router.Group("/suno")
|
||||
relaySunoRouter.Use(middleware.RouteTag("relay"))
|
||||
relaySunoRouter.Use(middleware.SystemPerformanceCheck())
|
||||
relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relaySunoRouter.POST("/submit/:action", controller.RelayTask)
|
||||
relaySunoRouter.POST("/fetch", controller.RelayTask)
|
||||
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
|
||||
relaySunoRouter.POST("/fetch", controller.RelayTaskFetch)
|
||||
relaySunoRouter.GET("/fetch/:id", controller.RelayTaskFetch)
|
||||
}
|
||||
|
||||
relayGeminiRouter := router.Group("/v1beta")
|
||||
relayGeminiRouter.Use(middleware.RouteTag("relay"))
|
||||
relayGeminiRouter.Use(middleware.SystemPerformanceCheck())
|
||||
relayGeminiRouter.Use(middleware.TokenAuth())
|
||||
relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
|
||||
|
||||
@@ -8,32 +8,42 @@ import (
|
||||
)
|
||||
|
||||
func SetVideoRouter(router *gin.Engine) {
|
||||
// Video proxy: accepts either session auth (dashboard) or token auth (API clients)
|
||||
videoProxyRouter := router.Group("/v1")
|
||||
videoProxyRouter.Use(middleware.RouteTag("relay"))
|
||||
videoProxyRouter.Use(middleware.TokenOrUserAuth())
|
||||
{
|
||||
videoProxyRouter.GET("/videos/:task_id/content", controller.VideoProxy)
|
||||
}
|
||||
|
||||
videoV1Router := router.Group("/v1")
|
||||
videoV1Router.Use(middleware.RouteTag("relay"))
|
||||
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
videoV1Router.GET("/videos/:task_id/content", controller.VideoProxy)
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTaskFetch)
|
||||
videoV1Router.POST("/videos/:video_id/remix", controller.RelayTask)
|
||||
}
|
||||
// openai compatible API video routes
|
||||
// docs: https://platform.openai.com/docs/api-reference/videos/create
|
||||
{
|
||||
videoV1Router.POST("/videos", controller.RelayTask)
|
||||
videoV1Router.GET("/videos/:task_id", controller.RelayTask)
|
||||
videoV1Router.GET("/videos/:task_id", controller.RelayTaskFetch)
|
||||
}
|
||||
|
||||
klingV1Router := router.Group("/kling/v1")
|
||||
klingV1Router.Use(middleware.RouteTag("relay"))
|
||||
klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
klingV1Router.POST("/videos/text2video", controller.RelayTask)
|
||||
klingV1Router.POST("/videos/image2video", controller.RelayTask)
|
||||
klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
|
||||
klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
|
||||
klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTaskFetch)
|
||||
klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTaskFetch)
|
||||
}
|
||||
|
||||
// Jimeng official API routes - direct mapping to official API format
|
||||
jimengOfficialGroup := router.Group("jimeng")
|
||||
jimengOfficialGroup.Use(middleware.RouteTag("relay"))
|
||||
jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
// Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31
|
||||
|
||||
@@ -19,6 +19,7 @@ func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
router.Use(middleware.Cache())
|
||||
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist")))
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
c.Set(middleware.RouteTagKey, "web")
|
||||
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") {
|
||||
controller.RelayNotFound(c)
|
||||
return
|
||||
|
||||
@@ -193,6 +193,11 @@ func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIErro
|
||||
|
||||
// shouldTrust 统一信任额度检查,适用于钱包和订阅。
|
||||
func (s *BillingSession) shouldTrust(c *gin.Context) bool {
|
||||
// 异步任务(ForcePreConsume=true)必须预扣全额,不允许信任旁路
|
||||
if s.relayInfo.ForcePreConsume {
|
||||
return false
|
||||
}
|
||||
|
||||
trustQuota := common.GetTrustQuota()
|
||||
if trustQuota <= 0 {
|
||||
return false
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/pkg/cachex"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/hot"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -61,6 +62,12 @@ type ChannelAffinityStatsContext struct {
|
||||
TTLSeconds int64
|
||||
}
|
||||
|
||||
const (
|
||||
cacheTokenRateModeCachedOverPrompt = "cached_over_prompt"
|
||||
cacheTokenRateModeCachedOverPromptPlusCached = "cached_over_prompt_plus_cached"
|
||||
cacheTokenRateModeMixed = "mixed"
|
||||
)
|
||||
|
||||
type ChannelAffinityCacheStats struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Total int `json:"total"`
|
||||
@@ -565,9 +572,10 @@ func RecordChannelAffinity(c *gin.Context, channelID int) {
|
||||
}
|
||||
|
||||
type ChannelAffinityUsageCacheStats struct {
|
||||
RuleName string `json:"rule_name"`
|
||||
UsingGroup string `json:"using_group"`
|
||||
KeyFingerprint string `json:"key_fp"`
|
||||
RuleName string `json:"rule_name"`
|
||||
UsingGroup string `json:"using_group"`
|
||||
KeyFingerprint string `json:"key_fp"`
|
||||
CachedTokenRateMode string `json:"cached_token_rate_mode"`
|
||||
|
||||
Hit int64 `json:"hit"`
|
||||
Total int64 `json:"total"`
|
||||
@@ -582,6 +590,8 @@ type ChannelAffinityUsageCacheStats struct {
|
||||
}
|
||||
|
||||
type ChannelAffinityUsageCacheCounters struct {
|
||||
CachedTokenRateMode string `json:"cached_token_rate_mode"`
|
||||
|
||||
Hit int64 `json:"hit"`
|
||||
Total int64 `json:"total"`
|
||||
WindowSeconds int64 `json:"window_seconds"`
|
||||
@@ -596,12 +606,17 @@ type ChannelAffinityUsageCacheCounters struct {
|
||||
|
||||
var channelAffinityUsageCacheStatsLocks [64]sync.Mutex
|
||||
|
||||
func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage) {
|
||||
// ObserveChannelAffinityUsageCacheByRelayFormat records usage cache stats with a stable rate mode derived from relay format.
|
||||
func ObserveChannelAffinityUsageCacheByRelayFormat(c *gin.Context, usage *dto.Usage, relayFormat types.RelayFormat) {
|
||||
ObserveChannelAffinityUsageCacheFromContext(c, usage, cachedTokenRateModeByRelayFormat(relayFormat))
|
||||
}
|
||||
|
||||
func ObserveChannelAffinityUsageCacheFromContext(c *gin.Context, usage *dto.Usage, cachedTokenRateMode string) {
|
||||
statsCtx, ok := GetChannelAffinityStatsContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
observeChannelAffinityUsageCache(statsCtx, usage)
|
||||
observeChannelAffinityUsageCache(statsCtx, usage, cachedTokenRateMode)
|
||||
}
|
||||
|
||||
func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) ChannelAffinityUsageCacheStats {
|
||||
@@ -628,6 +643,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann
|
||||
}
|
||||
}
|
||||
return ChannelAffinityUsageCacheStats{
|
||||
CachedTokenRateMode: v.CachedTokenRateMode,
|
||||
RuleName: ruleName,
|
||||
UsingGroup: usingGroup,
|
||||
KeyFingerprint: keyFp,
|
||||
@@ -643,7 +659,7 @@ func GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFp string) Chann
|
||||
}
|
||||
}
|
||||
|
||||
func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage) {
|
||||
func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usage *dto.Usage, cachedTokenRateMode string) {
|
||||
entryKey := channelAffinityUsageCacheEntryKey(statsCtx.RuleName, statsCtx.UsingGroup, statsCtx.KeyFingerprint)
|
||||
if entryKey == "" {
|
||||
return
|
||||
@@ -669,6 +685,14 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag
|
||||
if !found {
|
||||
next = ChannelAffinityUsageCacheCounters{}
|
||||
}
|
||||
currentMode := normalizeCachedTokenRateMode(cachedTokenRateMode)
|
||||
if currentMode != "" {
|
||||
if next.CachedTokenRateMode == "" {
|
||||
next.CachedTokenRateMode = currentMode
|
||||
} else if next.CachedTokenRateMode != currentMode && next.CachedTokenRateMode != cacheTokenRateModeMixed {
|
||||
next.CachedTokenRateMode = cacheTokenRateModeMixed
|
||||
}
|
||||
}
|
||||
next.Total++
|
||||
hit, cachedTokens, promptCacheHitTokens := usageCacheSignals(usage)
|
||||
if hit {
|
||||
@@ -684,6 +708,30 @@ func observeChannelAffinityUsageCache(statsCtx ChannelAffinityStatsContext, usag
|
||||
_ = cache.SetWithTTL(entryKey, next, ttl)
|
||||
}
|
||||
|
||||
func normalizeCachedTokenRateMode(mode string) string {
|
||||
switch mode {
|
||||
case cacheTokenRateModeCachedOverPrompt:
|
||||
return cacheTokenRateModeCachedOverPrompt
|
||||
case cacheTokenRateModeCachedOverPromptPlusCached:
|
||||
return cacheTokenRateModeCachedOverPromptPlusCached
|
||||
case cacheTokenRateModeMixed:
|
||||
return cacheTokenRateModeMixed
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func cachedTokenRateModeByRelayFormat(relayFormat types.RelayFormat) string {
|
||||
switch relayFormat {
|
||||
case types.RelayFormatOpenAI, types.RelayFormatOpenAIResponses, types.RelayFormatOpenAIResponsesCompaction:
|
||||
return cacheTokenRateModeCachedOverPrompt
|
||||
case types.RelayFormatClaude:
|
||||
return cacheTokenRateModeCachedOverPromptPlusCached
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func channelAffinityUsageCacheEntryKey(ruleName, usingGroup, keyFp string) string {
|
||||
ruleName = strings.TrimSpace(ruleName)
|
||||
usingGroup = strings.TrimSpace(usingGroup)
|
||||
|
||||
105
service/channel_affinity_usage_cache_test.go
Normal file
105
service/channel_affinity_usage_cache_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context {
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
setChannelAffinityContext(ctx, channelAffinityMeta{
|
||||
CacheKey: fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP),
|
||||
TTLSeconds: 600,
|
||||
RuleName: ruleName,
|
||||
UsingGroup: usingGroup,
|
||||
KeyFingerprint: keyFP,
|
||||
})
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) {
|
||||
ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
|
||||
usingGroup := "default"
|
||||
keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
|
||||
ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
CompletionTokens: 40,
|
||||
TotalTokens: 140,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 30,
|
||||
},
|
||||
}
|
||||
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude)
|
||||
stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
|
||||
|
||||
require.EqualValues(t, 1, stats.Total)
|
||||
require.EqualValues(t, 1, stats.Hit)
|
||||
require.EqualValues(t, 100, stats.PromptTokens)
|
||||
require.EqualValues(t, 40, stats.CompletionTokens)
|
||||
require.EqualValues(t, 140, stats.TotalTokens)
|
||||
require.EqualValues(t, 30, stats.CachedTokens)
|
||||
require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode)
|
||||
}
|
||||
|
||||
func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) {
|
||||
ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
|
||||
usingGroup := "default"
|
||||
keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
|
||||
ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
|
||||
|
||||
openAIUsage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 10,
|
||||
},
|
||||
}
|
||||
claudeUsage := &dto.Usage{
|
||||
PromptTokens: 80,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 20,
|
||||
},
|
||||
}
|
||||
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI)
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude)
|
||||
stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
|
||||
|
||||
require.EqualValues(t, 2, stats.Total)
|
||||
require.EqualValues(t, 2, stats.Hit)
|
||||
require.EqualValues(t, 180, stats.PromptTokens)
|
||||
require.EqualValues(t, 30, stats.CachedTokens)
|
||||
require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode)
|
||||
}
|
||||
|
||||
func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) {
|
||||
ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano())
|
||||
usingGroup := "default"
|
||||
keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano())
|
||||
ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP)
|
||||
|
||||
usage := &dto.Usage{
|
||||
PromptTokens: 100,
|
||||
PromptTokensDetails: dto.InputTokenDetails{
|
||||
CachedTokens: 25,
|
||||
},
|
||||
}
|
||||
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini)
|
||||
stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP)
|
||||
|
||||
require.EqualValues(t, 1, stats.Total)
|
||||
require.EqualValues(t, 1, stats.Hit)
|
||||
require.EqualValues(t, 25, stats.CachedTokens)
|
||||
require.Equal(t, "", stats.CachedTokenRateMode)
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code
|
||||
refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
|
||||
res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct {
|
||||
}
|
||||
|
||||
func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
|
||||
client := &http.Client{Timeout: defaultHTTPTimeout}
|
||||
return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "")
|
||||
}
|
||||
|
||||
func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) {
|
||||
client, err := getCodexOAuthHTTPClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
|
||||
}
|
||||
|
||||
func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
|
||||
client := &http.Client{Timeout: defaultHTTPTimeout}
|
||||
return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "")
|
||||
}
|
||||
|
||||
func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) {
|
||||
client, err := getCodexOAuthHTTPClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
|
||||
}
|
||||
|
||||
@@ -104,7 +120,7 @@ func refreshCodexOAuthToken(
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
if err := common.DecodeJson(resp.Body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
@@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode(
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
if err := common.DecodeJson(resp.Body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
@@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) {
|
||||
baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if baseClient == nil {
|
||||
return &http.Client{Timeout: defaultHTTPTimeout}, nil
|
||||
}
|
||||
clientCopy := *baseClient
|
||||
clientCopy.Timeout = defaultHTTPTimeout
|
||||
return &clientCopy, nil
|
||||
}
|
||||
|
||||
func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
|
||||
u, err := url.Parse(codexOAuthAuthorizeURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -206,3 +206,16 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
|
||||
|
||||
return taskError
|
||||
}
|
||||
|
||||
// TaskErrorFromAPIError 将 PreConsumeBilling 返回的 NewAPIError 转换为 TaskError。
|
||||
func TaskErrorFromAPIError(apiErr *types.NewAPIError) *dto.TaskError {
|
||||
if apiErr == nil {
|
||||
return nil
|
||||
}
|
||||
return &dto.TaskError{
|
||||
Code: string(apiErr.GetErrorCode()),
|
||||
Message: apiErr.Err.Error(),
|
||||
StatusCode: apiErr.StatusCode,
|
||||
Error: apiErr.Err,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,7 +204,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PerCallPriceData) map[string]interface{} {
|
||||
func GenerateMjOtherInfo(relayInfo *relaycommon.RelayInfo, priceData types.PriceData) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = priceData.ModelPrice
|
||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CoverActionToModelName(mjAction string) string {
|
||||
func CovertMjpActionToModelName(mjAction string) string {
|
||||
modelName := "mj_" + strings.ToLower(mjAction)
|
||||
if mjAction == constant.MjActionSwapFace {
|
||||
modelName = "swap_face"
|
||||
@@ -70,7 +70,7 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
|
||||
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
|
||||
}
|
||||
}
|
||||
modelName := CoverActionToModelName(action)
|
||||
modelName := CovertMjpActionToModelName(action)
|
||||
return modelName, nil, true
|
||||
}
|
||||
|
||||
|
||||
@@ -236,6 +236,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
|
||||
if usage != nil {
|
||||
ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, relayInfo.GetFinalRequestRelayFormat())
|
||||
}
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
|
||||
285
service/task_billing.go
Normal file
285
service/task_billing.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
|
||||
// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
|
||||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
// 支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
other["model_price"] = info.PriceData.ModelPrice
|
||||
other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
|
||||
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
if info.IsModelMapped {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = info.UpstreamModelName
|
||||
}
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: tokenName,
|
||||
Quota: info.PriceData.Quota,
|
||||
Content: logContent,
|
||||
TokenId: info.TokenId,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
|
||||
model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 异步任务计费辅助函数
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
|
||||
// 如果令牌已被删除或查询失败,返回空字符串。
|
||||
func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
|
||||
token, err := model.GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
|
||||
return ""
|
||||
}
|
||||
return token.Key
|
||||
}
|
||||
|
||||
// taskIsSubscription 判断任务是否通过订阅计费。
|
||||
func taskIsSubscription(task *model.Task) bool {
|
||||
return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
|
||||
}
|
||||
|
||||
// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
|
||||
func taskAdjustFunding(task *model.Task, delta int) error {
|
||||
if taskIsSubscription(task) {
|
||||
return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
|
||||
}
|
||||
if delta > 0 {
|
||||
return model.DecreaseUserQuota(task.UserId, delta)
|
||||
}
|
||||
return model.IncreaseUserQuota(task.UserId, -delta, false)
|
||||
}
|
||||
|
||||
// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
|
||||
// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
|
||||
func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
|
||||
if task.PrivateData.TokenId <= 0 || delta == 0 {
|
||||
return
|
||||
}
|
||||
tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
|
||||
if tokenKey == "" {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if delta > 0 {
|
||||
err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
|
||||
} else {
|
||||
err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
|
||||
func taskBillingOther(task *model.Task) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
if bc := task.PrivateData.BillingContext; bc != nil {
|
||||
other["model_price"] = bc.ModelPrice
|
||||
other["group_ratio"] = bc.GroupRatio
|
||||
if len(bc.OtherRatios) > 0 {
|
||||
for k, v := range bc.OtherRatios {
|
||||
other[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
props := task.Properties
|
||||
if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = props.UpstreamModelName
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
|
||||
func taskModelName(task *model.Task) string {
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
|
||||
return bc.OriginModelName
|
||||
}
|
||||
return task.Properties.OriginModelName
|
||||
}
|
||||
|
||||
// RefundTaskQuota 统一的任务失败退款逻辑。
|
||||
// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
|
||||
func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
|
||||
quota := task.Quota
|
||||
if quota == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 退还资金来源(钱包或订阅)
|
||||
if err := taskAdjustFunding(task, -quota); err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 退还令牌额度
|
||||
taskAdjustTokenQuota(ctx, task, -quota)
|
||||
|
||||
// 3. 记录日志
|
||||
other := taskBillingOther(task)
|
||||
other["task_id"] = task.TaskID
|
||||
other["reason"] = reason
|
||||
model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
|
||||
UserId: task.UserId,
|
||||
LogType: model.LogTypeRefund,
|
||||
Content: "",
|
||||
ChannelId: task.ChannelId,
|
||||
ModelName: taskModelName(task),
|
||||
Quota: quota,
|
||||
TokenId: task.PrivateData.TokenId,
|
||||
Group: task.Group,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
// RecalculateTaskQuota 通用的异步差额结算。
|
||||
// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
|
||||
// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
|
||||
func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
|
||||
if actualQuota <= 0 {
|
||||
return
|
||||
}
|
||||
preConsumedQuota := task.Quota
|
||||
quotaDelta := actualQuota - preConsumedQuota
|
||||
|
||||
if quotaDelta == 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
|
||||
task.TaskID, logger.LogQuota(actualQuota), reason))
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(quotaDelta),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
reason,
|
||||
))
|
||||
|
||||
// 调整资金来源
|
||||
if err := taskAdjustFunding(task, quotaDelta); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 调整令牌额度
|
||||
taskAdjustTokenQuota(ctx, task, quotaDelta)
|
||||
|
||||
task.Quota = actualQuota
|
||||
|
||||
var logType int
|
||||
var logQuota int
|
||||
if quotaDelta > 0 {
|
||||
logType = model.LogTypeConsume
|
||||
logQuota = quotaDelta
|
||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||
} else {
|
||||
logType = model.LogTypeRefund
|
||||
logQuota = -quotaDelta
|
||||
}
|
||||
other := taskBillingOther(task)
|
||||
other["task_id"] = task.TaskID
|
||||
other["reason"] = reason
|
||||
other["pre_consumed_quota"] = preConsumedQuota
|
||||
other["actual_quota"] = actualQuota
|
||||
model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
|
||||
UserId: task.UserId,
|
||||
LogType: logType,
|
||||
Content: "",
|
||||
ChannelId: task.ChannelId,
|
||||
ModelName: taskModelName(task),
|
||||
Quota: logQuota,
|
||||
TokenId: task.PrivateData.TokenId,
|
||||
Group: task.Group,
|
||||
Other: other,
|
||||
})
|
||||
}
|
||||
|
||||
// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
|
||||
// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
|
||||
// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
|
||||
func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
|
||||
if totalTokens <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
modelName := taskModelName(task)
|
||||
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if !hasRatioSetting || modelRatio <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户和组的倍率信息
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group == "" {
|
||||
return
|
||||
}
|
||||
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
finalGroupRatio = userGroupRatio
|
||||
} else {
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
|
||||
|
||||
reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, reason)
|
||||
}
|
||||
712
service/task_billing_test.go
Normal file
712
service/task_billing_test.go
Normal file
@@ -0,0 +1,712 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
panic("failed to open test db: " + err.Error())
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
panic("failed to get sql.DB: " + err.Error())
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
model.DB = db
|
||||
model.LOG_DB = db
|
||||
|
||||
common.UsingSQLite = true
|
||||
common.RedisEnabled = false
|
||||
common.BatchUpdateEnabled = false
|
||||
common.LogConsumeEnabled = true
|
||||
|
||||
if err := db.AutoMigrate(
|
||||
&model.Task{},
|
||||
&model.User{},
|
||||
&model.Token{},
|
||||
&model.Log{},
|
||||
&model.Channel{},
|
||||
&model.UserSubscription{},
|
||||
); err != nil {
|
||||
panic("failed to migrate: " + err.Error())
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Seed helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func truncate(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Cleanup(func() {
|
||||
model.DB.Exec("DELETE FROM tasks")
|
||||
model.DB.Exec("DELETE FROM users")
|
||||
model.DB.Exec("DELETE FROM tokens")
|
||||
model.DB.Exec("DELETE FROM logs")
|
||||
model.DB.Exec("DELETE FROM channels")
|
||||
model.DB.Exec("DELETE FROM user_subscriptions")
|
||||
})
|
||||
}
|
||||
|
||||
func seedUser(t *testing.T, id int, quota int) {
|
||||
t.Helper()
|
||||
user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
|
||||
require.NoError(t, model.DB.Create(user).Error)
|
||||
}
|
||||
|
||||
func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
|
||||
t.Helper()
|
||||
token := &model.Token{
|
||||
Id: id,
|
||||
UserId: userId,
|
||||
Key: key,
|
||||
Name: "test_token",
|
||||
Status: common.TokenStatusEnabled,
|
||||
RemainQuota: remainQuota,
|
||||
UsedQuota: 0,
|
||||
}
|
||||
require.NoError(t, model.DB.Create(token).Error)
|
||||
}
|
||||
|
||||
func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
|
||||
t.Helper()
|
||||
sub := &model.UserSubscription{
|
||||
Id: id,
|
||||
UserId: userId,
|
||||
AmountTotal: amountTotal,
|
||||
AmountUsed: amountUsed,
|
||||
Status: "active",
|
||||
StartTime: time.Now().Unix(),
|
||||
EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(),
|
||||
}
|
||||
require.NoError(t, model.DB.Create(sub).Error)
|
||||
}
|
||||
|
||||
func seedChannel(t *testing.T, id int) {
|
||||
t.Helper()
|
||||
ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
|
||||
require.NoError(t, model.DB.Create(ch).Error)
|
||||
}
|
||||
|
||||
func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
|
||||
return &model.Task{
|
||||
TaskID: "task_" + time.Now().Format("150405.000"),
|
||||
UserId: userId,
|
||||
ChannelId: channelId,
|
||||
Quota: quota,
|
||||
Status: model.TaskStatus(model.TaskStatusInProgress),
|
||||
Group: "default",
|
||||
Data: json.RawMessage(`{}`),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
Properties: model.Properties{
|
||||
OriginModelName: "test-model",
|
||||
},
|
||||
PrivateData: model.TaskPrivateData{
|
||||
BillingSource: billingSource,
|
||||
SubscriptionId: subscriptionId,
|
||||
TokenId: tokenId,
|
||||
BillingContext: &model.TaskBillingContext{
|
||||
ModelPrice: 0.02,
|
||||
GroupRatio: 1.0,
|
||||
OriginModelName: "test-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Read-back helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func getUserQuota(t *testing.T, id int) int {
|
||||
t.Helper()
|
||||
var user model.User
|
||||
require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
|
||||
return user.Quota
|
||||
}
|
||||
|
||||
func getTokenRemainQuota(t *testing.T, id int) int {
|
||||
t.Helper()
|
||||
var token model.Token
|
||||
require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
|
||||
return token.RemainQuota
|
||||
}
|
||||
|
||||
func getTokenUsedQuota(t *testing.T, id int) int {
|
||||
t.Helper()
|
||||
var token model.Token
|
||||
require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
|
||||
return token.UsedQuota
|
||||
}
|
||||
|
||||
func getSubscriptionUsed(t *testing.T, id int) int64 {
|
||||
t.Helper()
|
||||
var sub model.UserSubscription
|
||||
require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
|
||||
return sub.AmountUsed
|
||||
}
|
||||
|
||||
func getLastLog(t *testing.T) *model.Log {
|
||||
t.Helper()
|
||||
var log model.Log
|
||||
err := model.LOG_DB.Order("id desc").First(&log).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &log
|
||||
}
|
||||
|
||||
func countLogs(t *testing.T) int64 {
|
||||
t.Helper()
|
||||
var count int64
|
||||
model.LOG_DB.Model(&model.Log{}).Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// RefundTaskQuota tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestRefundTaskQuota_Wallet(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 1, 1, 1
|
||||
const initQuota, preConsumed = 10000, 3000
|
||||
const tokenRemain = 5000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
|
||||
RefundTaskQuota(ctx, task, "task failed: upstream error")
|
||||
|
||||
// User quota should increase by preConsumed
|
||||
assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
|
||||
|
||||
// Token remain_quota should increase, used_quota should decrease
|
||||
assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))
|
||||
|
||||
// A refund log should be created
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
assert.Equal(t, preConsumed, log.Quota)
|
||||
assert.Equal(t, "test-model", log.ModelName)
|
||||
}
|
||||
|
||||
func TestRefundTaskQuota_Subscription(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID, subID = 2, 2, 2, 1
|
||||
const preConsumed = 2000
|
||||
const subTotal, subUsed int64 = 100000, 50000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, 0)
|
||||
seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
seedSubscription(t, subID, userID, subTotal, subUsed)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
|
||||
|
||||
RefundTaskQuota(ctx, task, "subscription task failed")
|
||||
|
||||
// Subscription used should decrease by preConsumed
|
||||
assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))
|
||||
|
||||
// Token should also be refunded
|
||||
assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
|
||||
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
|
||||
func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID = 3
|
||||
seedUser(t, userID, 5000)
|
||||
|
||||
task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)
|
||||
|
||||
RefundTaskQuota(ctx, task, "zero quota task")
|
||||
|
||||
// No change to user quota
|
||||
assert.Equal(t, 5000, getUserQuota(t, userID))
|
||||
|
||||
// No log created
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestRefundTaskQuota_NoToken(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, channelID = 4, 4
|
||||
const initQuota, preConsumed = 10000, 1500
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0
|
||||
|
||||
RefundTaskQuota(ctx, task, "no token task failed")
|
||||
|
||||
// User quota refunded
|
||||
assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
|
||||
|
||||
// Log created
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// RecalculateTaskQuota tests
|
||||
// ===========================================================================
|
||||
|
||||
func TestRecalculate_PositiveDelta(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 10, 10, 10
|
||||
const initQuota, preConsumed = 10000, 2000
|
||||
const actualQuota = 3000 // under-charged by 1000
|
||||
const tokenRemain = 5000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
|
||||
|
||||
// User quota should decrease by the delta (1000 additional charge)
|
||||
assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))
|
||||
|
||||
// Token should also be charged the delta
|
||||
assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))
|
||||
|
||||
// task.Quota should be updated to actualQuota
|
||||
assert.Equal(t, actualQuota, task.Quota)
|
||||
|
||||
// Log type should be Consume (additional charge)
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeConsume, log.Type)
|
||||
assert.Equal(t, actualQuota-preConsumed, log.Quota)
|
||||
}
|
||||
|
||||
func TestRecalculate_NegativeDelta(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 11, 11, 11
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const actualQuota = 3000 // over-charged by 2000
|
||||
const tokenRemain = 5000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")
|
||||
|
||||
// User quota should increase by abs(delta) = 2000 (refund overpayment)
|
||||
assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
|
||||
|
||||
// Token should be refunded the difference
|
||||
assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
|
||||
|
||||
// task.Quota updated
|
||||
assert.Equal(t, actualQuota, task.Quota)
|
||||
|
||||
// Log type should be Refund
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
assert.Equal(t, preConsumed-actualQuota, log.Quota)
|
||||
}
|
||||
|
||||
func TestRecalculate_ZeroDelta(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID = 12
|
||||
const initQuota, preConsumed = 10000, 3000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
|
||||
task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)
|
||||
|
||||
RecalculateTaskQuota(ctx, task, preConsumed, "exact match")
|
||||
|
||||
// No change to user quota
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
|
||||
// No log created (delta is zero)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestRecalculate_ActualQuotaZero(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID = 13
|
||||
const initQuota = 10000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
|
||||
task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)
|
||||
|
||||
RecalculateTaskQuota(ctx, task, 0, "zero actual")
|
||||
|
||||
// No change (early return)
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID, subID = 14, 14, 14, 2
|
||||
const preConsumed = 5000
|
||||
const actualQuota = 2000 // over-charged by 3000
|
||||
const subTotal, subUsed int64 = 100000, 50000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, 0)
|
||||
seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
seedSubscription(t, subID, userID, subTotal, subUsed)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)
|
||||
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")
|
||||
|
||||
// Subscription used should decrease by delta (refund 3000)
|
||||
assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))
|
||||
|
||||
// Token refunded
|
||||
assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
|
||||
|
||||
assert.Equal(t, actualQuota, task.Quota)
|
||||
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// CAS + Billing integration tests
|
||||
// Simulates the flow in updateVideoSingleTask (service/task_polling.go)
|
||||
// ===========================================================================
|
||||
|
||||
// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
|
||||
// It takes a persisted task (already in DB), applies the new status, and performs
|
||||
// the conditional update + billing exactly as the polling loop does.
|
||||
func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
|
||||
snap := task.Snapshot()
|
||||
|
||||
shouldRefund := false
|
||||
shouldSettle := false
|
||||
quota := task.Quota
|
||||
|
||||
task.Status = newStatus
|
||||
switch string(newStatus) {
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
task.FinishTime = 9999
|
||||
shouldSettle = true
|
||||
case model.TaskStatusFailure:
|
||||
task.Progress = "100%"
|
||||
task.FinishTime = 9999
|
||||
task.FailReason = "upstream error"
|
||||
if quota != 0 {
|
||||
shouldRefund = true
|
||||
}
|
||||
default:
|
||||
task.Progress = "50%"
|
||||
}
|
||||
|
||||
isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
|
||||
if isDone && snap.Status != task.Status {
|
||||
won, err := task.UpdateWithStatus(snap.Status)
|
||||
if err != nil {
|
||||
shouldRefund = false
|
||||
shouldSettle = false
|
||||
} else if !won {
|
||||
shouldRefund = false
|
||||
shouldSettle = false
|
||||
}
|
||||
} else if !snap.Equal(task.Snapshot()) {
|
||||
_, _ = task.UpdateWithStatus(snap.Status)
|
||||
}
|
||||
|
||||
if shouldSettle && actualQuota > 0 {
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
|
||||
}
|
||||
if shouldRefund {
|
||||
RefundTaskQuota(ctx, task, task.FailReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCASGuardedRefund_Win(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 20, 20, 20
|
||||
const initQuota, preConsumed = 10000, 4000
|
||||
const tokenRemain = 6000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.Status = model.TaskStatus(model.TaskStatusInProgress)
|
||||
require.NoError(t, model.DB.Create(task).Error)
|
||||
|
||||
simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
|
||||
|
||||
// CAS wins: task in DB should now be FAILURE
|
||||
var reloaded model.Task
|
||||
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
|
||||
assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)
|
||||
|
||||
// Refund should have happened
|
||||
assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
|
||||
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
|
||||
func TestCASGuardedRefund_Lose(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 21, 21, 21
|
||||
const initQuota, preConsumed = 10000, 4000
|
||||
const tokenRemain = 6000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
// Create task with IN_PROGRESS in DB
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.Status = model.TaskStatus(model.TaskStatusInProgress)
|
||||
require.NoError(t, model.DB.Create(task).Error)
|
||||
|
||||
// Simulate another process already transitioning to FAILURE
|
||||
model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)
|
||||
|
||||
// Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
|
||||
// task.Status is still IN_PROGRESS in the snapshot
|
||||
simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)
|
||||
|
||||
// CAS lost: user quota should NOT change (no double refund)
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
|
||||
// No billing log should be created
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestCASGuardedSettle_Win(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 22, 22, 22
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const actualQuota = 3000 // over-charged, should get partial refund
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.Status = model.TaskStatus(model.TaskStatusInProgress)
|
||||
require.NoError(t, model.DB.Create(task).Error)
|
||||
|
||||
simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)
|
||||
|
||||
// CAS wins: task should be SUCCESS
|
||||
var reloaded model.Task
|
||||
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
|
||||
assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)
|
||||
|
||||
// Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
|
||||
assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))
|
||||
|
||||
// task.Quota should be updated to actualQuota
|
||||
assert.Equal(t, actualQuota, task.Quota)
|
||||
}
|
||||
|
||||
func TestNonTerminalUpdate_NoBilling(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, channelID = 23, 23
|
||||
const initQuota, preConsumed = 10000, 3000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
|
||||
task.Status = model.TaskStatus(model.TaskStatusInProgress)
|
||||
task.Progress = "20%"
|
||||
require.NoError(t, model.DB.Create(task).Error)
|
||||
|
||||
// Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
|
||||
simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)
|
||||
|
||||
// User quota should NOT change
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
|
||||
// No billing log
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
|
||||
// Task progress should be updated in DB
|
||||
var reloaded model.Task
|
||||
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
|
||||
assert.Equal(t, "50%", reloaded.Progress)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Mock adaptor for settleTaskBillingOnComplete tests
|
||||
// ===========================================================================
|
||||
|
||||
type mockAdaptor struct {
|
||||
adjustReturn int
|
||||
}
|
||||
|
||||
func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
|
||||
func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
|
||||
func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
|
||||
func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||
return m.adjustReturn
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// PerCallBilling tests — settleTaskBillingOnComplete
|
||||
// ===========================================================================
|
||||
|
||||
func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 30, 30, 30
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.PrivateData.BillingContext.PerCallBilling = true
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: 2000}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Per-call: no adjustment despite adaptor returning 2000
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, preConsumed, task.Quota)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 31, 31, 31
|
||||
const initQuota, preConsumed = 10000, 4000
|
||||
const tokenRemain = 7000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.PrivateData.BillingContext.PerCallBilling = true
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: 0}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Per-call: no recalculation by tokens
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, preConsumed, task.Quota)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 32, 32, 32
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const adaptorQuota = 3000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
// PerCallBilling defaults to false
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Non-per-call: adaptor adjustment applies (refund 2000)
|
||||
assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, adaptorQuota, task.Quota)
|
||||
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
539
service/task_polling.go
Normal file
539
service/task_polling.go
Normal file
@@ -0,0 +1,539 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// TaskPollingAdaptor 定义轮询所需的最小适配器接口,避免 service -> relay 的循环依赖
|
||||
type TaskPollingAdaptor interface {
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||
ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
|
||||
// AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
|
||||
// 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
|
||||
AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
|
||||
}
|
||||
|
||||
// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
|
||||
// 打破 service -> relay -> relay/channel -> service 的循环依赖。
|
||||
var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor
|
||||
|
||||
// sweepTimedOutTasks 在主轮询之前独立清理超时任务。
|
||||
// 每次最多处理 100 条,剩余的下个周期继续处理。
|
||||
// 使用 per-task CAS (UpdateWithStatus) 防止覆盖被正常轮询已推进的任务。
|
||||
func sweepTimedOutTasks(ctx context.Context) {
|
||||
if constant.TaskTimeoutMinutes <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60
|
||||
tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100)
|
||||
if len(tasks) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
const legacyTaskCutoff int64 = 1740182400 // 2026-02-22 00:00:00 UTC
|
||||
reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes)
|
||||
legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)"
|
||||
now := time.Now().Unix()
|
||||
timedOutCount := 0
|
||||
|
||||
for _, task := range tasks {
|
||||
isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff
|
||||
|
||||
oldStatus := task.Status
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
task.FinishTime = now
|
||||
if isLegacy {
|
||||
task.FailReason = legacyReason
|
||||
} else {
|
||||
task.FailReason = reason
|
||||
}
|
||||
|
||||
won, err := task.UpdateWithStatus(oldStatus)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err))
|
||||
continue
|
||||
}
|
||||
if !won {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID))
|
||||
continue
|
||||
}
|
||||
timedOutCount++
|
||||
if !isLegacy && task.Quota != 0 {
|
||||
RefundTaskQuota(ctx, task, reason)
|
||||
}
|
||||
}
|
||||
|
||||
if timedOutCount > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount))
|
||||
}
|
||||
}
|
||||
|
||||
// TaskPollingLoop 主轮询循环,每 15 秒检查一次未完成的任务
|
||||
func TaskPollingLoop() {
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
sweepTimedOutTasks(ctx)
|
||||
allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
}
|
||||
for platform, tasks := range platformTask {
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Task)
|
||||
nullTaskIds := make([]int64, 0)
|
||||
for _, task := range tasks {
|
||||
upstreamID := task.GetUpstreamTaskID()
|
||||
if upstreamID == "" {
|
||||
// 统计失败的未完成任务
|
||||
nullTaskIds = append(nullTaskIds, task.ID)
|
||||
continue
|
||||
}
|
||||
taskM[upstreamID] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID)
|
||||
}
|
||||
if len(nullTaskIds) > 0 {
|
||||
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
} else {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
DispatchPlatformUpdate(platform, taskChannelM, taskM)
|
||||
}
|
||||
common.SysLog("任务进度轮询完成")
|
||||
}
|
||||
}
|
||||
|
||||
// DispatchPlatformUpdate 按平台分发轮询更新
|
||||
func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
|
||||
switch platform {
|
||||
case constant.TaskPlatformMidjourney:
|
||||
// MJ 轮询由其自身处理,这里预留入口
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTasks(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSunoTasks 按渠道更新所有 Suno 任务
|
||||
func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTasks(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
ch, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
|
||||
var failedIDs []int64
|
||||
for _, upstreamID := range taskIds {
|
||||
if t, ok := taskM[upstreamID]; ok {
|
||||
failedIDs = append(failedIDs, t.ID)
|
||||
}
|
||||
}
|
||||
err = model.TaskBulkUpdateByID(failedIDs, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno)
|
||||
if adaptor == nil {
|
||||
return errors.New("adaptor not found")
|
||||
}
|
||||
proxy := ch.GetSetting().Proxy
|
||||
resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{
|
||||
"ids": taskIds,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Get Suno Task parse body error: %v", err))
|
||||
return err
|
||||
}
|
||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||
err = common.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Suno Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
for _, responseItem := range responseItems.Data {
|
||||
task := taskM[responseItem.TaskID]
|
||||
if !taskNeedsUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
|
||||
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
|
||||
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
|
||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
RefundTaskQuota(ctx, task, task.FailReason)
|
||||
}
|
||||
if responseItem.Status == model.TaskStatusSuccess {
|
||||
task.Progress = "100%"
|
||||
}
|
||||
task.Data = responseItem.Data
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.SysLog("UpdateSunoTask task error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// taskNeedsUpdate 检查 Suno 任务是否需要更新
|
||||
func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if string(oldTask.Status) != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
|
||||
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
|
||||
return true
|
||||
}
|
||||
|
||||
oldData, _ := common.Marshal(oldTask.Data)
|
||||
newData, _ := common.Marshal(newTask.Data)
|
||||
|
||||
sort.Slice(oldData, func(i, j int) bool {
|
||||
return oldData[i] < oldData[j]
|
||||
})
|
||||
sort.Slice(newData, func(i, j int) bool {
|
||||
return newData[i] < newData[j]
|
||||
})
|
||||
|
||||
if string(oldData) != string(newData) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateVideoTasks 按渠道更新所有视频任务
|
||||
func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
// Collect DB primary key IDs for bulk update (taskIds are upstream IDs, not task_id column values)
|
||||
var failedIDs []int64
|
||||
for _, upstreamID := range taskIds {
|
||||
if t, ok := taskM[upstreamID]; ok {
|
||||
failedIDs = append(failedIDs, t.ID)
|
||||
}
|
||||
}
|
||||
errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := GetTaskAdaptorFunc(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
info := &relaycommon.RelayInfo{}
|
||||
info.ChannelMeta = &relaycommon.ChannelMeta{
|
||||
ChannelBaseUrl: cacheGetChannel.GetBaseURL(),
|
||||
}
|
||||
info.ApiKey = cacheGetChannel.Key
|
||||
adaptor.Init(info)
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[ch.Type]
|
||||
if ch.GetBaseURL() != "" {
|
||||
baseURL = ch.GetBaseURL()
|
||||
}
|
||||
proxy := ch.GetSetting().Proxy
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
key := ch.Key
|
||||
|
||||
privateData := task.PrivateData
|
||||
if privateData.Key != "" {
|
||||
key = privateData.Key
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
|
||||
"task_id": task.GetUpstreamTaskID(),
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody)))
|
||||
|
||||
snap := task.Snapshot()
|
||||
|
||||
taskResult := &relaycommon.TaskInfo{}
|
||||
// try parse as New API response format
|
||||
var responseItems dto.TaskResponse[model.Task]
|
||||
if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||
logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems))
|
||||
t := responseItems.Data
|
||||
taskResult.TaskID = t.TaskID
|
||||
taskResult.Status = string(t.Status)
|
||||
taskResult.Url = t.GetResultURL()
|
||||
taskResult.Progress = t.Progress
|
||||
taskResult.Reason = t.FailReason
|
||||
task.Data = t.Data
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult))
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
shouldRefund := false
|
||||
shouldSettle := false
|
||||
quota := task.Quota
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = taskcommon.ProgressSubmitted
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = taskcommon.ProgressQueued
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = taskcommon.ProgressInProgress
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = taskcommon.ProgressComplete
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
if strings.HasPrefix(taskResult.Url, "data:") {
|
||||
// data: URI (e.g. Vertex base64 encoded video) — keep in Data, not in ResultURL
|
||||
} else if taskResult.Url != "" {
|
||||
// Direct upstream URL (e.g. Kling, Ali, Doubao, etc.)
|
||||
task.PrivateData.ResultURL = taskResult.Url
|
||||
} else {
|
||||
// No URL from adaptor — construct proxy URL using public task ID
|
||||
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
|
||||
}
|
||||
shouldSettle = true
|
||||
case model.TaskStatusFailure:
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = taskcommon.ProgressComplete
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
taskResult.Progress = taskcommon.ProgressComplete
|
||||
if quota != 0 {
|
||||
shouldRefund = true
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure
|
||||
if isDone && snap.Status != task.Status {
|
||||
won, err := task.UpdateWithStatus(snap.Status)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error()))
|
||||
shouldRefund = false
|
||||
shouldSettle = false
|
||||
} else if !won {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID))
|
||||
shouldRefund = false
|
||||
shouldSettle = false
|
||||
}
|
||||
} else if !snap.Equal(task.Snapshot()) {
|
||||
if _, err := task.UpdateWithStatus(snap.Status); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error()))
|
||||
}
|
||||
} else {
|
||||
// No changes, skip update
|
||||
logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID))
|
||||
}
|
||||
|
||||
if shouldSettle {
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
}
|
||||
if shouldRefund {
|
||||
RefundTaskQuota(ctx, task, task.FailReason)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := common.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := common.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
|
||||
// settleTaskBillingOnComplete 任务完成时的统一计费调整。
|
||||
// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
|
||||
//
|
||||
// 2. taskResult.TotalTokens > 0 → 按 token 重算
|
||||
// 3. 都不满足 → 保持预扣额度不变
|
||||
func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
|
||||
// 0. 按次计费的任务不做差额结算
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
|
||||
return
|
||||
}
|
||||
// 1. 优先让 adaptor 决定最终额度
|
||||
if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
|
||||
return
|
||||
}
|
||||
// 2. 回退到 token 重算
|
||||
if taskResult.TotalTokens > 0 {
|
||||
RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
|
||||
return
|
||||
}
|
||||
// 3. 无调整,保持预扣额度
|
||||
}
|
||||
@@ -18,8 +18,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ViolationFeeCodePrefix = "violation_fee."
|
||||
CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE_CSAM"
|
||||
ViolationFeeCodePrefix = "violation_fee."
|
||||
CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE"
|
||||
ContentViolatesUsageMarker = "Content violates usage guidelines"
|
||||
)
|
||||
|
||||
func IsViolationFeeCode(code types.ErrorCode) bool {
|
||||
@@ -30,11 +31,11 @@ func HasCSAMViolationMarker(err *types.NewAPIError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(err.Error(), CSAMViolationMarker) {
|
||||
if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) {
|
||||
return true
|
||||
}
|
||||
msg := err.ToOpenAIError().Message
|
||||
return strings.Contains(msg, CSAMViolationMarker)
|
||||
return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker)
|
||||
}
|
||||
|
||||
func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError {
|
||||
|
||||
@@ -26,6 +26,11 @@ var AutomaticRetryStatusCodeRanges = []StatusCodeRange{
|
||||
{Start: 525, End: 599},
|
||||
}
|
||||
|
||||
var alwaysSkipRetryStatusCodes = map[int]struct{}{
|
||||
504: {},
|
||||
524: {},
|
||||
}
|
||||
|
||||
func AutomaticDisableStatusCodesToString() string {
|
||||
return statusCodeRangesToString(AutomaticDisableStatusCodeRanges)
|
||||
}
|
||||
@@ -56,7 +61,15 @@ func AutomaticRetryStatusCodesFromString(s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsAlwaysSkipRetryStatusCode(code int) bool {
|
||||
_, exists := alwaysSkipRetryStatusCodes[code]
|
||||
return exists
|
||||
}
|
||||
|
||||
func ShouldRetryByStatusCode(code int) bool {
|
||||
if IsAlwaysSkipRetryStatusCode(code) {
|
||||
return false
|
||||
}
|
||||
return shouldMatchStatusCodeRanges(AutomaticRetryStatusCodeRanges, code)
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,8 @@ func TestShouldRetryByStatusCode(t *testing.T) {
|
||||
|
||||
require.True(t, ShouldRetryByStatusCode(429))
|
||||
require.True(t, ShouldRetryByStatusCode(500))
|
||||
require.False(t, ShouldRetryByStatusCode(504))
|
||||
require.False(t, ShouldRetryByStatusCode(524))
|
||||
require.False(t, ShouldRetryByStatusCode(400))
|
||||
require.False(t, ShouldRetryByStatusCode(200))
|
||||
}
|
||||
@@ -77,3 +79,9 @@ func TestShouldRetryByStatusCode_DefaultMatchesLegacyBehavior(t *testing.T) {
|
||||
require.False(t, ShouldRetryByStatusCode(524))
|
||||
require.True(t, ShouldRetryByStatusCode(599))
|
||||
}
|
||||
|
||||
func TestIsAlwaysSkipRetryStatusCode(t *testing.T) {
|
||||
require.True(t, IsAlwaysSkipRetryStatusCode(504))
|
||||
require.True(t, IsAlwaysSkipRetryStatusCode(524))
|
||||
require.False(t, IsAlwaysSkipRetryStatusCode(500))
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ type PriceData struct {
|
||||
AudioCompletionRatio float64
|
||||
OtherRatios map[string]float64
|
||||
UsePrice bool
|
||||
QuotaToPreConsume int // 预消耗额度
|
||||
Quota int // 按次计费的最终额度(MJ / Task)
|
||||
QuotaToPreConsume int // 按量计费的预消耗额度
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
@@ -36,12 +37,6 @@ func (p *PriceData) AddOtherRatio(key string, ratio float64) {
|
||||
p.OtherRatios[key] = ratio
|
||||
}
|
||||
|
||||
type PerCallPriceData struct {
|
||||
ModelPrice float64
|
||||
Quota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p *PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
showSuccess,
|
||||
updateAPI,
|
||||
getSystemName,
|
||||
getOAuthProviderIcon,
|
||||
setUserData,
|
||||
onGitHubOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
@@ -130,6 +131,17 @@ const LoginForm = () => {
|
||||
return {};
|
||||
}
|
||||
}, [statusState?.status]);
|
||||
const hasCustomOAuthProviders =
|
||||
(status.custom_oauth_providers || []).length > 0;
|
||||
const hasOAuthLoginOptions = Boolean(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
status.telegram_oauth ||
|
||||
hasCustomOAuthProviders,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (status?.turnstile_check) {
|
||||
@@ -598,7 +610,7 @@ const LoginForm = () => {
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<IconLock size='large' />}
|
||||
icon={getOAuthProviderIcon(provider.icon || '', 20)}
|
||||
onClick={() => handleCustomOAuthClick(provider)}
|
||||
loading={customOAuthLoading[provider.slug]}
|
||||
>
|
||||
@@ -817,12 +829,7 @@ const LoginForm = () => {
|
||||
</div>
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
status.telegram_oauth) && (
|
||||
{hasOAuthLoginOptions && (
|
||||
<>
|
||||
<Divider margin='12px' align='center'>
|
||||
{t('或')}
|
||||
@@ -952,14 +959,7 @@ const LoginForm = () => {
|
||||
/>
|
||||
<div className='w-full max-w-sm mt-[60px]'>
|
||||
{showEmailLogin ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
status.telegram_oauth
|
||||
)
|
||||
!hasOAuthLoginOptions
|
||||
? renderEmailLoginForm()
|
||||
: renderOAuthOptions()}
|
||||
{renderWeChatLoginModal()}
|
||||
|
||||
@@ -27,8 +27,10 @@ import {
|
||||
showSuccess,
|
||||
updateAPI,
|
||||
getSystemName,
|
||||
getOAuthProviderIcon,
|
||||
setUserData,
|
||||
onDiscordOAuthClicked,
|
||||
onCustomOAuthClicked,
|
||||
} from '../../helpers';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import {
|
||||
@@ -98,6 +100,7 @@ const RegisterForm = () => {
|
||||
const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] =
|
||||
useState(false);
|
||||
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
|
||||
const [customOAuthLoading, setCustomOAuthLoading] = useState({});
|
||||
const [disableButton, setDisableButton] = useState(false);
|
||||
const [countdown, setCountdown] = useState(30);
|
||||
const [agreedToTerms, setAgreedToTerms] = useState(false);
|
||||
@@ -126,6 +129,17 @@ const RegisterForm = () => {
|
||||
return {};
|
||||
}
|
||||
}, [statusState?.status]);
|
||||
const hasCustomOAuthProviders =
|
||||
(status.custom_oauth_providers || []).length > 0;
|
||||
const hasOAuthRegisterOptions = Boolean(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
status.telegram_oauth ||
|
||||
hasCustomOAuthProviders,
|
||||
);
|
||||
|
||||
const [showEmailVerification, setShowEmailVerification] = useState(false);
|
||||
|
||||
@@ -319,6 +333,17 @@ const RegisterForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleCustomOAuthClick = (provider) => {
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true }));
|
||||
try {
|
||||
onCustomOAuthClicked(provider, { shouldLogout: true });
|
||||
} finally {
|
||||
setTimeout(() => {
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false }));
|
||||
}, 3000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleEmailRegisterClick = () => {
|
||||
setEmailRegisterLoading(true);
|
||||
setShowEmailRegister(true);
|
||||
@@ -469,6 +494,23 @@ const RegisterForm = () => {
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.custom_oauth_providers &&
|
||||
status.custom_oauth_providers.map((provider) => (
|
||||
<Button
|
||||
key={provider.slug}
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={getOAuthProviderIcon(provider.icon || '', 20)}
|
||||
onClick={() => handleCustomOAuthClick(provider)}
|
||||
loading={customOAuthLoading[provider.slug]}
|
||||
>
|
||||
<span className='ml-3'>
|
||||
{t('使用 {{name}} 继续', { name: provider.name })}
|
||||
</span>
|
||||
</Button>
|
||||
))}
|
||||
|
||||
{status.telegram_oauth && (
|
||||
<div className='flex justify-center my-2'>
|
||||
<TelegramLoginButton
|
||||
@@ -650,12 +692,7 @@ const RegisterForm = () => {
|
||||
</div>
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
status.telegram_oauth) && (
|
||||
{hasOAuthRegisterOptions && (
|
||||
<>
|
||||
<Divider margin='12px' align='center'>
|
||||
{t('或')}
|
||||
@@ -745,14 +782,7 @@ const RegisterForm = () => {
|
||||
/>
|
||||
<div className='w-full max-w-sm mt-[60px]'>
|
||||
{showEmailRegister ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
status.telegram_oauth
|
||||
)
|
||||
!hasOAuthRegisterOptions
|
||||
? renderEmailRegisterForm()
|
||||
: renderOAuthOptions()}
|
||||
{renderWeChatLoginModal()}
|
||||
|
||||
214
web/src/components/common/modals/RiskAcknowledgementModal.jsx
Normal file
214
web/src/components/common/modals/RiskAcknowledgementModal.jsx
Normal file
@@ -0,0 +1,214 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import {
|
||||
Modal,
|
||||
Button,
|
||||
Typography,
|
||||
Checkbox,
|
||||
Input,
|
||||
Space,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconAlertTriangle } from '@douyinfe/semi-icons';
|
||||
import { useIsMobile } from '../../../hooks/common/useIsMobile';
|
||||
import MarkdownRenderer from '../markdown/MarkdownRenderer';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const RiskMarkdownBlock = React.memo(function RiskMarkdownBlock({
|
||||
markdownContent,
|
||||
}) {
|
||||
if (!markdownContent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className='rounded-lg'
|
||||
style={{
|
||||
border: '1px solid var(--semi-color-warning-light-hover)',
|
||||
padding: '12px',
|
||||
contentVisibility: 'auto',
|
||||
}}
|
||||
>
|
||||
<MarkdownRenderer content={markdownContent} />
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const RiskAcknowledgementModal = React.memo(function RiskAcknowledgementModal({
|
||||
visible,
|
||||
title,
|
||||
markdownContent = '',
|
||||
detailTitle = '',
|
||||
detailItems = [],
|
||||
checklist = [],
|
||||
inputPrompt = '',
|
||||
requiredText = '',
|
||||
inputPlaceholder = '',
|
||||
mismatchText = '',
|
||||
cancelText = '',
|
||||
confirmText = '',
|
||||
onCancel,
|
||||
onConfirm,
|
||||
}) {
|
||||
const isMobile = useIsMobile();
|
||||
const [checkedItems, setCheckedItems] = useState([]);
|
||||
const [typedText, setTypedText] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
if (!visible) return;
|
||||
setCheckedItems(Array(checklist.length).fill(false));
|
||||
setTypedText('');
|
||||
}, [visible, checklist.length]);
|
||||
|
||||
const allChecked = useMemo(() => {
|
||||
if (checklist.length === 0) return true;
|
||||
return checkedItems.length === checklist.length && checkedItems.every(Boolean);
|
||||
}, [checkedItems, checklist.length]);
|
||||
|
||||
const typedMatched = useMemo(() => {
|
||||
if (!requiredText) return true;
|
||||
return typedText.trim() === requiredText.trim();
|
||||
}, [typedText, requiredText]);
|
||||
|
||||
const detailText = useMemo(() => detailItems.join(', '), [detailItems]);
|
||||
const canConfirm = allChecked && typedMatched;
|
||||
|
||||
const handleChecklistChange = useCallback((index, checked) => {
|
||||
setCheckedItems((previous) => {
|
||||
const next = [...previous];
|
||||
next[index] = checked;
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
visible={visible}
|
||||
title={
|
||||
<Space align='center'>
|
||||
<IconAlertTriangle style={{ color: 'var(--semi-color-warning)' }} />
|
||||
<span>{title}</span>
|
||||
</Space>
|
||||
}
|
||||
width={isMobile ? '100%' : 860}
|
||||
centered
|
||||
maskClosable={false}
|
||||
closeOnEsc={false}
|
||||
onCancel={onCancel}
|
||||
bodyStyle={{
|
||||
maxHeight: isMobile ? '70vh' : '72vh',
|
||||
overflowY: 'auto',
|
||||
padding: isMobile ? '12px 16px' : '18px 22px',
|
||||
}}
|
||||
footer={
|
||||
<Space>
|
||||
<Button onClick={onCancel}>{cancelText}</Button>
|
||||
<Button
|
||||
theme='solid'
|
||||
type='danger'
|
||||
disabled={!canConfirm}
|
||||
onClick={onConfirm}
|
||||
>
|
||||
{confirmText}
|
||||
</Button>
|
||||
</Space>
|
||||
}
|
||||
>
|
||||
<div className='flex flex-col gap-4'>
|
||||
|
||||
<RiskMarkdownBlock markdownContent={markdownContent} />
|
||||
|
||||
{detailItems.length > 0 ? (
|
||||
<div
|
||||
className='flex flex-col gap-2 rounded-lg'
|
||||
style={{
|
||||
border: '1px solid var(--semi-color-warning-light-hover)',
|
||||
background: 'var(--semi-color-fill-0)',
|
||||
padding: isMobile ? '10px 12px' : '12px 14px',
|
||||
}}
|
||||
>
|
||||
{detailTitle ? <Text strong>{detailTitle}</Text> : null}
|
||||
<div className='font-mono text-xs break-all bg-orange-50 border border-orange-200 rounded-md p-2'>
|
||||
{detailText}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{checklist.length > 0 ? (
|
||||
<div
|
||||
className='flex flex-col gap-2 rounded-lg'
|
||||
style={{
|
||||
border: '1px solid var(--semi-color-border)',
|
||||
background: 'var(--semi-color-fill-0)',
|
||||
padding: isMobile ? '10px 12px' : '12px 14px',
|
||||
}}
|
||||
>
|
||||
{checklist.map((item, index) => (
|
||||
<Checkbox
|
||||
key={`risk-check-${index}`}
|
||||
checked={!!checkedItems[index]}
|
||||
onChange={(event) => {
|
||||
handleChecklistChange(index, event.target.checked);
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
</Checkbox>
|
||||
))}
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{requiredText ? (
|
||||
<div
|
||||
className='flex flex-col gap-2 rounded-lg'
|
||||
style={{
|
||||
border: '1px solid var(--semi-color-danger-light-hover)',
|
||||
background: 'var(--semi-color-danger-light-default)',
|
||||
padding: isMobile ? '10px 12px' : '12px 14px',
|
||||
}}
|
||||
>
|
||||
{inputPrompt ? <Text strong>{inputPrompt}</Text> : null}
|
||||
<div className='font-mono text-xs break-all rounded-md p-2 bg-gray-50 border border-gray-200'>
|
||||
{requiredText}
|
||||
</div>
|
||||
<Input
|
||||
value={typedText}
|
||||
onChange={setTypedText}
|
||||
placeholder={inputPlaceholder}
|
||||
autoFocus={visible}
|
||||
onCopy={(event) => event.preventDefault()}
|
||||
onCut={(event) => event.preventDefault()}
|
||||
onPaste={(event) => event.preventDefault()}
|
||||
onDrop={(event) => event.preventDefault()}
|
||||
/>
|
||||
{!typedMatched && typedText ? (
|
||||
<Text type='danger' size='small'>
|
||||
{mismatchText}
|
||||
</Text>
|
||||
) : null}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
export default RiskAcknowledgementModal;
|
||||
@@ -27,14 +27,20 @@ import {
|
||||
Modal,
|
||||
Banner,
|
||||
Card,
|
||||
Collapse,
|
||||
Switch,
|
||||
Table,
|
||||
Tag,
|
||||
Popconfirm,
|
||||
Space,
|
||||
Select,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons';
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
import {
|
||||
IconPlus,
|
||||
IconEdit,
|
||||
IconDelete,
|
||||
IconRefresh,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { API, showError, showSuccess, getOAuthProviderIcon } from '../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
@@ -120,6 +126,69 @@ const OAUTH_PRESETS = {
|
||||
},
|
||||
};
|
||||
|
||||
const OAUTH_PRESET_ICONS = {
|
||||
'github-enterprise': 'github',
|
||||
gitlab: 'gitlab',
|
||||
gitea: 'gitea',
|
||||
nextcloud: 'nextcloud',
|
||||
keycloak: 'keycloak',
|
||||
authentik: 'authentik',
|
||||
ory: 'openid',
|
||||
};
|
||||
|
||||
const getPresetIcon = (preset) => OAUTH_PRESET_ICONS[preset] || '';
|
||||
|
||||
const PRESET_RESET_VALUES = {
|
||||
name: '',
|
||||
slug: '',
|
||||
icon: '',
|
||||
authorization_endpoint: '',
|
||||
token_endpoint: '',
|
||||
user_info_endpoint: '',
|
||||
scopes: '',
|
||||
user_id_field: '',
|
||||
username_field: '',
|
||||
display_name_field: '',
|
||||
email_field: '',
|
||||
well_known: '',
|
||||
auth_style: 0,
|
||||
access_policy: '',
|
||||
access_denied_message: '',
|
||||
};
|
||||
|
||||
const DISCOVERY_FIELD_LABELS = {
|
||||
authorization_endpoint: 'Authorization Endpoint',
|
||||
token_endpoint: 'Token Endpoint',
|
||||
user_info_endpoint: 'User Info Endpoint',
|
||||
scopes: 'Scopes',
|
||||
user_id_field: 'User ID Field',
|
||||
username_field: 'Username Field',
|
||||
display_name_field: 'Display Name Field',
|
||||
email_field: 'Email Field',
|
||||
};
|
||||
|
||||
const ACCESS_POLICY_TEMPLATES = {
|
||||
level_active: `{
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{"field": "trust_level", "op": "gte", "value": 2},
|
||||
{"field": "active", "op": "eq", "value": true}
|
||||
]
|
||||
}`,
|
||||
org_or_role: `{
|
||||
"logic": "or",
|
||||
"conditions": [
|
||||
{"field": "org", "op": "eq", "value": "core"},
|
||||
{"field": "roles", "op": "contains", "value": "admin"}
|
||||
]
|
||||
}`,
|
||||
};
|
||||
|
||||
const ACCESS_DENIED_TEMPLATES = {
|
||||
level_hint: '需要等级 {{required}},你当前等级 {{current}}(字段:{{field}})',
|
||||
org_hint: '仅限指定组织或角色访问。组织={{current.org}},角色={{current.roles}}',
|
||||
};
|
||||
|
||||
const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
const { t } = useTranslation();
|
||||
const [providers, setProviders] = useState([]);
|
||||
@@ -129,8 +198,47 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
const [formValues, setFormValues] = useState({});
|
||||
const [selectedPreset, setSelectedPreset] = useState('');
|
||||
const [baseUrl, setBaseUrl] = useState('');
|
||||
const [discoveryLoading, setDiscoveryLoading] = useState(false);
|
||||
const [discoveryInfo, setDiscoveryInfo] = useState(null);
|
||||
const [advancedActiveKeys, setAdvancedActiveKeys] = useState([]);
|
||||
const formApiRef = React.useRef(null);
|
||||
|
||||
const mergeFormValues = (newValues) => {
|
||||
setFormValues((prev) => ({ ...prev, ...newValues }));
|
||||
if (!formApiRef.current) return;
|
||||
Object.entries(newValues).forEach(([key, value]) => {
|
||||
formApiRef.current.setValue(key, value);
|
||||
});
|
||||
};
|
||||
|
||||
const getLatestFormValues = () => {
|
||||
const values = formApiRef.current?.getValues?.();
|
||||
return values && typeof values === 'object' ? values : formValues;
|
||||
};
|
||||
|
||||
const normalizeBaseUrl = (url) => (url || '').trim().replace(/\/+$/, '');
|
||||
|
||||
const inferBaseUrlFromProvider = (provider) => {
|
||||
const endpoint = provider?.authorization_endpoint || provider?.token_endpoint;
|
||||
if (!endpoint) return '';
|
||||
try {
|
||||
const url = new URL(endpoint);
|
||||
return `${url.protocol}//${url.host}`;
|
||||
} catch (error) {
|
||||
return '';
|
||||
}
|
||||
};
|
||||
|
||||
const resetDiscoveryState = () => {
|
||||
setDiscoveryInfo(null);
|
||||
};
|
||||
|
||||
const closeModal = () => {
|
||||
setModalVisible(false);
|
||||
resetDiscoveryState();
|
||||
setAdvancedActiveKeys([]);
|
||||
};
|
||||
|
||||
const fetchProviders = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
@@ -154,23 +262,30 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
setEditingProvider(null);
|
||||
setFormValues({
|
||||
enabled: false,
|
||||
icon: '',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'sub',
|
||||
username_field: 'preferred_username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
auth_style: 0,
|
||||
access_policy: '',
|
||||
access_denied_message: '',
|
||||
});
|
||||
setSelectedPreset('');
|
||||
setBaseUrl('');
|
||||
resetDiscoveryState();
|
||||
setAdvancedActiveKeys([]);
|
||||
setModalVisible(true);
|
||||
};
|
||||
|
||||
const handleEdit = (provider) => {
|
||||
setEditingProvider(provider);
|
||||
setFormValues({ ...provider });
|
||||
setSelectedPreset('');
|
||||
setBaseUrl('');
|
||||
setSelectedPreset(OAUTH_PRESETS[provider.slug] ? provider.slug : '');
|
||||
setBaseUrl(inferBaseUrlFromProvider(provider));
|
||||
resetDiscoveryState();
|
||||
setAdvancedActiveKeys([]);
|
||||
setModalVisible(true);
|
||||
};
|
||||
|
||||
@@ -189,6 +304,8 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
};
|
||||
|
||||
const handleSubmit = async () => {
|
||||
const currentValues = getLatestFormValues();
|
||||
|
||||
// Validate required fields
|
||||
const requiredFields = [
|
||||
'name',
|
||||
@@ -204,7 +321,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
}
|
||||
|
||||
for (const field of requiredFields) {
|
||||
if (!formValues[field]) {
|
||||
if (!currentValues[field]) {
|
||||
showError(t(`请填写 ${field}`));
|
||||
return;
|
||||
}
|
||||
@@ -213,11 +330,11 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
// Validate endpoint URLs must be full URLs
|
||||
const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint'];
|
||||
for (const field of endpointFields) {
|
||||
const value = formValues[field];
|
||||
const value = currentValues[field];
|
||||
if (value && !value.startsWith('http://') && !value.startsWith('https://')) {
|
||||
// Check if user selected a preset but forgot to fill server address
|
||||
// Check if user selected a preset but forgot to fill issuer URL
|
||||
if (selectedPreset && !baseUrl) {
|
||||
showError(t('请先填写服务器地址,以自动生成完整的端点 URL'));
|
||||
showError(t('请先填写 Issuer URL,以自动生成完整的端点 URL'));
|
||||
} else {
|
||||
showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)'));
|
||||
}
|
||||
@@ -226,80 +343,199 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
}
|
||||
|
||||
try {
|
||||
const payload = { ...currentValues, enabled: !!currentValues.enabled };
|
||||
delete payload.preset;
|
||||
delete payload.base_url;
|
||||
|
||||
let res;
|
||||
if (editingProvider) {
|
||||
res = await API.put(
|
||||
`/api/custom-oauth-provider/${editingProvider.id}`,
|
||||
formValues
|
||||
payload
|
||||
);
|
||||
} else {
|
||||
res = await API.post('/api/custom-oauth-provider/', formValues);
|
||||
res = await API.post('/api/custom-oauth-provider/', payload);
|
||||
}
|
||||
|
||||
if (res.data.success) {
|
||||
showSuccess(editingProvider ? t('更新成功') : t('创建成功'));
|
||||
setModalVisible(false);
|
||||
closeModal();
|
||||
fetchProviders();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(editingProvider ? t('更新失败') : t('创建失败'));
|
||||
showError(
|
||||
error?.response?.data?.message ||
|
||||
(editingProvider ? t('更新失败') : t('创建失败')),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleFetchFromDiscovery = async () => {
|
||||
const cleanBaseUrl = normalizeBaseUrl(baseUrl);
|
||||
const configuredWellKnown = (formValues.well_known || '').trim();
|
||||
const wellKnownUrl =
|
||||
configuredWellKnown ||
|
||||
(cleanBaseUrl ? `${cleanBaseUrl}/.well-known/openid-configuration` : '');
|
||||
|
||||
if (!wellKnownUrl) {
|
||||
showError(t('请先填写 Discovery URL 或 Issuer URL'));
|
||||
return;
|
||||
}
|
||||
|
||||
setDiscoveryLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/custom-oauth-provider/discovery', {
|
||||
well_known_url: configuredWellKnown || '',
|
||||
issuer_url: cleanBaseUrl || '',
|
||||
});
|
||||
if (!res.data.success) {
|
||||
throw new Error(res.data.message || t('未知错误'));
|
||||
}
|
||||
const data = res.data.data?.discovery || {};
|
||||
const resolvedWellKnown = res.data.data?.well_known_url || wellKnownUrl;
|
||||
|
||||
const discoveredValues = {
|
||||
well_known: resolvedWellKnown,
|
||||
};
|
||||
const autoFilledFields = [];
|
||||
if (data.authorization_endpoint) {
|
||||
discoveredValues.authorization_endpoint = data.authorization_endpoint;
|
||||
autoFilledFields.push('authorization_endpoint');
|
||||
}
|
||||
if (data.token_endpoint) {
|
||||
discoveredValues.token_endpoint = data.token_endpoint;
|
||||
autoFilledFields.push('token_endpoint');
|
||||
}
|
||||
if (data.userinfo_endpoint) {
|
||||
discoveredValues.user_info_endpoint = data.userinfo_endpoint;
|
||||
autoFilledFields.push('user_info_endpoint');
|
||||
}
|
||||
|
||||
const scopesSupported = Array.isArray(data.scopes_supported)
|
||||
? data.scopes_supported
|
||||
: [];
|
||||
if (scopesSupported.length > 0 && !formValues.scopes) {
|
||||
const preferredScopes = ['openid', 'profile', 'email'].filter((scope) =>
|
||||
scopesSupported.includes(scope),
|
||||
);
|
||||
discoveredValues.scopes =
|
||||
preferredScopes.length > 0
|
||||
? preferredScopes.join(' ')
|
||||
: scopesSupported.slice(0, 5).join(' ');
|
||||
autoFilledFields.push('scopes');
|
||||
}
|
||||
|
||||
const claimsSupported = Array.isArray(data.claims_supported)
|
||||
? data.claims_supported
|
||||
: [];
|
||||
const claimMap = {
|
||||
user_id_field: 'sub',
|
||||
username_field: 'preferred_username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
};
|
||||
Object.entries(claimMap).forEach(([field, claim]) => {
|
||||
if (!formValues[field] && claimsSupported.includes(claim)) {
|
||||
discoveredValues[field] = claim;
|
||||
autoFilledFields.push(field);
|
||||
}
|
||||
});
|
||||
|
||||
const hasCoreEndpoint =
|
||||
discoveredValues.authorization_endpoint ||
|
||||
discoveredValues.token_endpoint ||
|
||||
discoveredValues.user_info_endpoint;
|
||||
if (!hasCoreEndpoint) {
|
||||
showError(t('未在 Discovery 响应中找到可用的 OAuth 端点'));
|
||||
return;
|
||||
}
|
||||
|
||||
mergeFormValues(discoveredValues);
|
||||
setDiscoveryInfo({
|
||||
wellKnown: wellKnownUrl,
|
||||
autoFilledFields,
|
||||
scopesSupported: scopesSupported.slice(0, 12),
|
||||
claimsSupported: claimsSupported.slice(0, 12),
|
||||
});
|
||||
showSuccess(t('已从 Discovery 自动填充配置'));
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('获取 Discovery 配置失败:') + (error?.message || t('未知错误')),
|
||||
);
|
||||
} finally {
|
||||
setDiscoveryLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handlePresetChange = (preset) => {
|
||||
setSelectedPreset(preset);
|
||||
if (preset && OAUTH_PRESETS[preset]) {
|
||||
const presetConfig = OAUTH_PRESETS[preset];
|
||||
const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : '';
|
||||
const newValues = {
|
||||
name: presetConfig.name,
|
||||
slug: preset,
|
||||
scopes: presetConfig.scopes,
|
||||
user_id_field: presetConfig.user_id_field,
|
||||
username_field: presetConfig.username_field,
|
||||
display_name_field: presetConfig.display_name_field,
|
||||
email_field: presetConfig.email_field,
|
||||
auth_style: presetConfig.auth_style ?? 0,
|
||||
};
|
||||
// Only fill endpoints if server address is provided
|
||||
if (cleanUrl) {
|
||||
newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint;
|
||||
newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint;
|
||||
newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint;
|
||||
}
|
||||
setFormValues((prev) => ({ ...prev, ...newValues }));
|
||||
// Update form fields directly via formApi
|
||||
if (formApiRef.current) {
|
||||
Object.entries(newValues).forEach(([key, value]) => {
|
||||
formApiRef.current.setValue(key, value);
|
||||
});
|
||||
}
|
||||
resetDiscoveryState();
|
||||
const cleanUrl = normalizeBaseUrl(baseUrl);
|
||||
if (!preset || !OAUTH_PRESETS[preset]) {
|
||||
mergeFormValues(PRESET_RESET_VALUES);
|
||||
return;
|
||||
}
|
||||
|
||||
const presetConfig = OAUTH_PRESETS[preset];
|
||||
const newValues = {
|
||||
...PRESET_RESET_VALUES,
|
||||
name: presetConfig.name,
|
||||
slug: preset,
|
||||
icon: getPresetIcon(preset),
|
||||
scopes: presetConfig.scopes,
|
||||
user_id_field: presetConfig.user_id_field,
|
||||
username_field: presetConfig.username_field,
|
||||
display_name_field: presetConfig.display_name_field,
|
||||
email_field: presetConfig.email_field,
|
||||
auth_style: presetConfig.auth_style ?? 0,
|
||||
};
|
||||
if (cleanUrl) {
|
||||
newValues.authorization_endpoint =
|
||||
cleanUrl + presetConfig.authorization_endpoint;
|
||||
newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint;
|
||||
newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint;
|
||||
}
|
||||
mergeFormValues(newValues);
|
||||
};
|
||||
|
||||
const handleBaseUrlChange = (url) => {
|
||||
setBaseUrl(url);
|
||||
if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) {
|
||||
const presetConfig = OAUTH_PRESETS[selectedPreset];
|
||||
const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes
|
||||
const cleanUrl = normalizeBaseUrl(url);
|
||||
const newValues = {
|
||||
authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint,
|
||||
token_endpoint: cleanUrl + presetConfig.token_endpoint,
|
||||
user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint,
|
||||
};
|
||||
setFormValues((prev) => ({ ...prev, ...newValues }));
|
||||
// Update form fields directly via formApi (use merge mode to preserve other fields)
|
||||
if (formApiRef.current) {
|
||||
Object.entries(newValues).forEach(([key, value]) => {
|
||||
formApiRef.current.setValue(key, value);
|
||||
});
|
||||
}
|
||||
mergeFormValues(newValues);
|
||||
}
|
||||
};
|
||||
|
||||
const applyAccessPolicyTemplate = (templateKey) => {
|
||||
const template = ACCESS_POLICY_TEMPLATES[templateKey];
|
||||
if (!template) return;
|
||||
mergeFormValues({ access_policy: template });
|
||||
showSuccess(t('已填充策略模板'));
|
||||
};
|
||||
|
||||
const applyDeniedTemplate = (templateKey) => {
|
||||
const template = ACCESS_DENIED_TEMPLATES[templateKey];
|
||||
if (!template) return;
|
||||
mergeFormValues({ access_denied_message: template });
|
||||
showSuccess(t('已填充提示模板'));
|
||||
};
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: t('图标'),
|
||||
dataIndex: 'icon',
|
||||
key: 'icon',
|
||||
width: 80,
|
||||
render: (icon) => getOAuthProviderIcon(icon || '', 18),
|
||||
},
|
||||
{
|
||||
title: t('名称'),
|
||||
dataIndex: 'name',
|
||||
@@ -325,7 +561,10 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
title: t('Client ID'),
|
||||
dataIndex: 'client_id',
|
||||
key: 'client_id',
|
||||
render: (id) => (id ? id.substring(0, 20) + '...' : '-'),
|
||||
render: (id) => {
|
||||
if (!id) return '-';
|
||||
return id.length > 20 ? `${id.substring(0, 20)}...` : id;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('操作'),
|
||||
@@ -352,6 +591,10 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
},
|
||||
];
|
||||
|
||||
const discoveryAutoFilledLabels = (discoveryInfo?.autoFilledFields || [])
|
||||
.map((field) => DISCOVERY_FIELD_LABELS[field] || field)
|
||||
.join(', ');
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<Form.Section text={t('自定义 OAuth 提供商')}>
|
||||
@@ -391,56 +634,142 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
<Modal
|
||||
title={editingProvider ? t('编辑 OAuth 提供商') : t('添加 OAuth 提供商')}
|
||||
visible={modalVisible}
|
||||
onOk={handleSubmit}
|
||||
onCancel={() => setModalVisible(false)}
|
||||
okText={t('保存')}
|
||||
cancelText={t('取消')}
|
||||
width={800}
|
||||
onCancel={closeModal}
|
||||
width={860}
|
||||
centered
|
||||
bodyStyle={{ maxHeight: '72vh', overflowY: 'auto', paddingRight: 6 }}
|
||||
footer={
|
||||
<div
|
||||
style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'flex-end',
|
||||
alignItems: 'center',
|
||||
gap: 12,
|
||||
flexWrap: 'wrap',
|
||||
}}
|
||||
>
|
||||
<Space spacing={8} align='center'>
|
||||
<Text type='secondary'>{t('启用供应商')}</Text>
|
||||
<Switch
|
||||
checked={!!formValues.enabled}
|
||||
size='large'
|
||||
onChange={(checked) => mergeFormValues({ enabled: !!checked })}
|
||||
/>
|
||||
<Tag color={formValues.enabled ? 'green' : 'grey'}>
|
||||
{formValues.enabled ? t('已启用') : t('已禁用')}
|
||||
</Tag>
|
||||
</Space>
|
||||
<Button onClick={closeModal}>{t('取消')}</Button>
|
||||
<Button type='primary' onClick={handleSubmit}>
|
||||
{t('保存')}
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<Form
|
||||
initValues={formValues}
|
||||
onValueChange={(values) => setFormValues(values)}
|
||||
onValueChange={() => {
|
||||
setFormValues((prev) => ({ ...prev, ...getLatestFormValues() }));
|
||||
}}
|
||||
getFormApi={(api) => (formApiRef.current = api)}
|
||||
>
|
||||
{!editingProvider && (
|
||||
<Row gutter={16} style={{ marginBottom: 16 }}>
|
||||
<Col span={12}>
|
||||
<Form.Select
|
||||
field="preset"
|
||||
label={t('预设模板')}
|
||||
placeholder={t('选择预设模板(可选)')}
|
||||
value={selectedPreset}
|
||||
onChange={handlePresetChange}
|
||||
optionList={[
|
||||
{ value: '', label: t('自定义') },
|
||||
...Object.entries(OAUTH_PRESETS).map(([key, config]) => ({
|
||||
value: key,
|
||||
label: config.name,
|
||||
})),
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="base_url"
|
||||
label={
|
||||
selectedPreset
|
||||
? t('服务器地址') + ' *'
|
||||
: t('服务器地址')
|
||||
}
|
||||
placeholder={t('例如:https://gitea.example.com')}
|
||||
value={baseUrl}
|
||||
onChange={handleBaseUrlChange}
|
||||
extraText={
|
||||
selectedPreset
|
||||
? t('必填:请输入服务器地址以自动生成完整端点 URL')
|
||||
: t('选择预设模板后填写服务器地址可自动填充端点')
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Text strong style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('Configuration')}
|
||||
</Text>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('先填写配置,再自动填充 OAuth 端点,能显著减少手工输入')}
|
||||
</Text>
|
||||
{discoveryInfo && (
|
||||
<Banner
|
||||
type='success'
|
||||
closeIcon={null}
|
||||
style={{ marginBottom: 12 }}
|
||||
description={
|
||||
<div>
|
||||
<div>
|
||||
{t('已从 Discovery 获取配置,可继续手动修改所有字段。')}
|
||||
</div>
|
||||
{discoveryAutoFilledLabels ? (
|
||||
<div>
|
||||
{t('自动填充字段')}:
|
||||
{' '}
|
||||
{discoveryAutoFilledLabels}
|
||||
</div>
|
||||
) : null}
|
||||
{discoveryInfo.scopesSupported?.length ? (
|
||||
<div>
|
||||
{t('Discovery scopes')}:
|
||||
{' '}
|
||||
{discoveryInfo.scopesSupported.join(', ')}
|
||||
</div>
|
||||
) : null}
|
||||
{discoveryInfo.claimsSupported?.length ? (
|
||||
<div>
|
||||
{t('Discovery claims')}:
|
||||
{' '}
|
||||
{discoveryInfo.claimsSupported.join(', ')}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={8}>
|
||||
<Form.Select
|
||||
field="preset"
|
||||
label={t('预设模板')}
|
||||
placeholder={t('选择预设模板(可选)')}
|
||||
value={selectedPreset}
|
||||
onChange={handlePresetChange}
|
||||
optionList={[
|
||||
{ value: '', label: t('自定义') },
|
||||
...Object.entries(OAUTH_PRESETS).map(([key, config]) => ({
|
||||
value: key,
|
||||
label: config.name,
|
||||
})),
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={10}>
|
||||
<Form.Input
|
||||
field="base_url"
|
||||
label={t('发行者 URL(Issuer URL)')}
|
||||
placeholder={t('例如:https://gitea.example.com')}
|
||||
value={baseUrl}
|
||||
onChange={handleBaseUrlChange}
|
||||
extraText={
|
||||
selectedPreset
|
||||
? t('填写后会自动拼接预设端点')
|
||||
: t('可选:用于自动生成端点或 Discovery URL')
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<div style={{ display: 'flex', alignItems: 'flex-end', height: '100%' }}>
|
||||
<Button
|
||||
icon={<IconRefresh />}
|
||||
onClick={handleFetchFromDiscovery}
|
||||
loading={discoveryLoading}
|
||||
block
|
||||
>
|
||||
{t('获取 Discovery 配置')}
|
||||
</Button>
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row gutter={16}>
|
||||
<Col span={24}>
|
||||
<Form.Input
|
||||
field="well_known"
|
||||
label={t('发现文档地址(Discovery URL,可选)')}
|
||||
placeholder={t('例如:https://example.com/.well-known/openid-configuration')}
|
||||
extraText={t('可留空;留空时会尝试使用 Issuer URL + /.well-known/openid-configuration')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
@@ -461,6 +790,41 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={18}>
|
||||
<Form.Input
|
||||
field='icon'
|
||||
label={t('图标')}
|
||||
placeholder={t('例如:github / si:google / https://example.com/logo.png / 🐱')}
|
||||
extraText={
|
||||
<span>
|
||||
{t(
|
||||
'图标使用 react-icons(Simple Icons)或 URL/emoji,例如:github、gitlab、si:google',
|
||||
)}
|
||||
</span>
|
||||
}
|
||||
showClear
|
||||
/>
|
||||
</Col>
|
||||
<Col span={6} style={{ display: 'flex', alignItems: 'flex-end' }}>
|
||||
<div
|
||||
style={{
|
||||
width: '100%',
|
||||
minHeight: 74,
|
||||
border: '1px solid var(--semi-color-border)',
|
||||
borderRadius: 8,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
marginBottom: 24,
|
||||
background: 'var(--semi-color-fill-0)',
|
||||
}}
|
||||
>
|
||||
{getOAuthProviderIcon(formValues.icon || '', 24)}
|
||||
</div>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
@@ -500,7 +864,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
label={t('Authorization Endpoint')}
|
||||
placeholder={
|
||||
selectedPreset && OAUTH_PRESETS[selectedPreset]
|
||||
? t('填写服务器地址后自动生成:') +
|
||||
? t('填写 Issuer URL 后自动生成:') +
|
||||
OAUTH_PRESETS[selectedPreset].authorization_endpoint
|
||||
: 'https://example.com/oauth/authorize'
|
||||
}
|
||||
@@ -544,15 +908,14 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="scopes"
|
||||
label={t('Scopes')}
|
||||
label={t('Scopes(可选)')}
|
||||
placeholder="openid profile email"
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="well_known"
|
||||
label={t('Well-Known URL')}
|
||||
placeholder={t('OIDC Discovery 端点(可选)')}
|
||||
extraText={
|
||||
discoveryInfo?.scopesSupported?.length
|
||||
? t('Discovery 建议 scopes:') +
|
||||
discoveryInfo.scopesSupported.join(', ')
|
||||
: t('可手动填写,多个 scope 用空格分隔')
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
@@ -568,7 +931,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="user_id_field"
|
||||
label={t('用户 ID 字段')}
|
||||
label={t('用户 ID 字段(可选)')}
|
||||
placeholder={t('例如:sub、id、data.user.id')}
|
||||
extraText={t('用于唯一标识用户的字段路径')}
|
||||
/>
|
||||
@@ -576,7 +939,7 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="username_field"
|
||||
label={t('用户名字段')}
|
||||
label={t('用户名字段(可选)')}
|
||||
placeholder={t('例如:preferred_username、login')}
|
||||
/>
|
||||
</Col>
|
||||
@@ -586,41 +949,100 @@ const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="display_name_field"
|
||||
label={t('显示名称字段')}
|
||||
label={t('显示名称字段(可选)')}
|
||||
placeholder={t('例如:name、full_name')}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="email_field"
|
||||
label={t('邮箱字段')}
|
||||
label={t('邮箱字段(可选)')}
|
||||
placeholder={t('例如:email')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
|
||||
{t('高级选项')}
|
||||
</Text>
|
||||
<Collapse
|
||||
keepDOM
|
||||
activeKey={advancedActiveKeys}
|
||||
style={{ marginTop: 16 }}
|
||||
onChange={(activeKey) => {
|
||||
const keys = Array.isArray(activeKey) ? activeKey : [activeKey];
|
||||
setAdvancedActiveKeys(keys.filter(Boolean));
|
||||
}}
|
||||
>
|
||||
<Collapse.Panel header={t('高级选项')} itemKey='advanced'>
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Select
|
||||
field="auth_style"
|
||||
label={t('认证方式')}
|
||||
optionList={[
|
||||
{ value: 0, label: t('自动检测') },
|
||||
{ value: 1, label: t('POST 参数') },
|
||||
{ value: 2, label: t('Basic Auth 头') },
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Select
|
||||
field="auth_style"
|
||||
label={t('认证方式')}
|
||||
optionList={[
|
||||
{ value: 0, label: t('自动检测') },
|
||||
{ value: 1, label: t('POST 参数') },
|
||||
{ value: 2, label: t('Basic Auth 头') },
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Checkbox field="enabled" noLabel>
|
||||
{t('启用此 OAuth 提供商')}
|
||||
</Form.Checkbox>
|
||||
</Col>
|
||||
</Row>
|
||||
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
|
||||
{t('准入策略')}
|
||||
</Text>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('可选:基于用户信息 JSON 做组合条件准入,条件不满足时返回自定义提示')}
|
||||
</Text>
|
||||
<Row gutter={16}>
|
||||
<Col span={24}>
|
||||
<Form.TextArea
|
||||
field='access_policy'
|
||||
value={formValues.access_policy || ''}
|
||||
onChange={(value) => mergeFormValues({ access_policy: value })}
|
||||
label={t('准入策略 JSON(可选)')}
|
||||
rows={6}
|
||||
placeholder={`{
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{"field": "trust_level", "op": "gte", "value": 2},
|
||||
{"field": "active", "op": "eq", "value": true}
|
||||
]
|
||||
}`}
|
||||
extraText={t('支持逻辑 and/or 与嵌套 groups;操作符支持 eq/ne/gt/gte/lt/lte/in/not_in/contains/exists')}
|
||||
showClear
|
||||
/>
|
||||
<Space spacing={8} style={{ marginTop: 8 }}>
|
||||
<Button size='small' theme='light' onClick={() => applyAccessPolicyTemplate('level_active')}>
|
||||
{t('填充模板:等级+激活')}
|
||||
</Button>
|
||||
<Button size='small' theme='light' onClick={() => applyAccessPolicyTemplate('org_or_role')}>
|
||||
{t('填充模板:组织或角色')}
|
||||
</Button>
|
||||
</Space>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row gutter={16}>
|
||||
<Col span={24}>
|
||||
<Form.Input
|
||||
field='access_denied_message'
|
||||
value={formValues.access_denied_message || ''}
|
||||
onChange={(value) => mergeFormValues({ access_denied_message: value })}
|
||||
label={t('拒绝提示模板(可选)')}
|
||||
placeholder={t('例如:需要等级 {{required}},你当前等级 {{current}}')}
|
||||
extraText={t('可用变量:{{provider}} {{field}} {{op}} {{required}} {{current}} 以及 {{current.path}}')}
|
||||
showClear
|
||||
/>
|
||||
<Space spacing={8} style={{ marginTop: 8 }}>
|
||||
<Button size='small' theme='light' onClick={() => applyDeniedTemplate('level_hint')}>
|
||||
{t('填充模板:等级提示')}
|
||||
</Button>
|
||||
<Button size='small' theme='light' onClick={() => applyDeniedTemplate('org_hint')}>
|
||||
{t('填充模板:组织提示')}
|
||||
</Button>
|
||||
</Space>
|
||||
</Col>
|
||||
</Row>
|
||||
</Collapse.Panel>
|
||||
</Collapse>
|
||||
</Form>
|
||||
</Modal>
|
||||
</Form.Section>
|
||||
|
||||
@@ -50,6 +50,7 @@ import {
|
||||
onLinuxDOOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
onCustomOAuthClicked,
|
||||
getOAuthProviderIcon,
|
||||
} from '../../../../helpers';
|
||||
import TwoFASetting from '../components/TwoFASetting';
|
||||
|
||||
@@ -148,12 +149,14 @@ const AccountManagement = ({
|
||||
|
||||
// Check if custom OAuth provider is bound
|
||||
const isCustomOAuthBound = (providerId) => {
|
||||
return customOAuthBindings.some((b) => b.provider_id === providerId);
|
||||
const normalizedId = Number(providerId);
|
||||
return customOAuthBindings.some((b) => Number(b.provider_id) === normalizedId);
|
||||
};
|
||||
|
||||
// Get binding info for a provider
|
||||
const getCustomOAuthBinding = (providerId) => {
|
||||
return customOAuthBindings.find((b) => b.provider_id === providerId);
|
||||
const normalizedId = Number(providerId);
|
||||
return customOAuthBindings.find((b) => Number(b.provider_id) === normalizedId);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
@@ -524,10 +527,10 @@ const AccountManagement = ({
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
<div className='flex items-center flex-1 min-w-0'>
|
||||
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
|
||||
<IconLock
|
||||
size='default'
|
||||
className='text-slate-600 dark:text-slate-300'
|
||||
/>
|
||||
{getOAuthProviderIcon(
|
||||
provider.icon || binding?.provider_icon || '',
|
||||
20,
|
||||
)}
|
||||
</div>
|
||||
<div className='flex-1 min-w-0'>
|
||||
<div className='font-medium text-gray-900'>
|
||||
|
||||
@@ -86,6 +86,7 @@ const NotificationSettings = ({
|
||||
channel: true,
|
||||
models: true,
|
||||
deployment: true,
|
||||
subscription: true,
|
||||
redemption: true,
|
||||
user: true,
|
||||
setting: true,
|
||||
@@ -169,6 +170,7 @@ const NotificationSettings = ({
|
||||
channel: true,
|
||||
models: true,
|
||||
deployment: true,
|
||||
subscription: true,
|
||||
redemption: true,
|
||||
user: true,
|
||||
setting: true,
|
||||
@@ -296,6 +298,11 @@ const NotificationSettings = ({
|
||||
title: t('模型部署'),
|
||||
description: t('模型部署管理'),
|
||||
},
|
||||
{
|
||||
key: 'subscription',
|
||||
title: t('订阅管理'),
|
||||
description: t('订阅套餐管理'),
|
||||
},
|
||||
{
|
||||
key: 'redemption',
|
||||
title: t('兑换码管理'),
|
||||
|
||||
@@ -62,9 +62,14 @@ import CodexOAuthModal from './CodexOAuthModal';
|
||||
import ParamOverrideEditorModal from './ParamOverrideEditorModal';
|
||||
import JSONEditor from '../../../common/ui/JSONEditor';
|
||||
import SecureVerificationModal from '../../../common/modals/SecureVerificationModal';
|
||||
import StatusCodeRiskGuardModal from './StatusCodeRiskGuardModal';
|
||||
import ChannelKeyDisplay from '../../../common/ui/ChannelKeyDisplay';
|
||||
import { useSecureVerification } from '../../../../hooks/common/useSecureVerification';
|
||||
import { createApiCalls } from '../../../../services/secureVerification';
|
||||
import {
|
||||
collectInvalidStatusCodeEntries,
|
||||
collectNewDisallowedStatusCodeRedirects,
|
||||
} from './statusCodeRiskGuard';
|
||||
import {
|
||||
IconSave,
|
||||
IconClose,
|
||||
@@ -195,6 +200,8 @@ const EditChannelModal = (props) => {
|
||||
allow_service_tier: false,
|
||||
disable_store: false, // false = 允许透传(默认开启)
|
||||
allow_safety_identifier: false,
|
||||
allow_include_obfuscation: false,
|
||||
allow_inference_geo: false,
|
||||
claude_beta_query: false,
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
@@ -209,6 +216,7 @@ const EditChannelModal = (props) => {
|
||||
const [fullModels, setFullModels] = useState([]);
|
||||
const [modelGroups, setModelGroups] = useState([]);
|
||||
const [customModel, setCustomModel] = useState('');
|
||||
const [modelSearchValue, setModelSearchValue] = useState('');
|
||||
const [modalImageUrl, setModalImageUrl] = useState('');
|
||||
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
||||
const [modelModalVisible, setModelModalVisible] = useState(false);
|
||||
@@ -249,6 +257,25 @@ const EditChannelModal = (props) => {
|
||||
return [];
|
||||
}
|
||||
}, [inputs.model_mapping]);
|
||||
const modelSearchMatchedCount = useMemo(() => {
|
||||
const keyword = modelSearchValue.trim();
|
||||
if (!keyword) {
|
||||
return modelOptions.length;
|
||||
}
|
||||
return modelOptions.reduce(
|
||||
(count, option) => count + (selectFilter(keyword, option) ? 1 : 0),
|
||||
0,
|
||||
);
|
||||
}, [modelOptions, modelSearchValue]);
|
||||
const modelSearchHintText = useMemo(() => {
|
||||
const keyword = modelSearchValue.trim();
|
||||
if (!keyword || modelSearchMatchedCount !== 0) {
|
||||
return '';
|
||||
}
|
||||
return t('未匹配到模型,按回车键可将「{{name}}」作为自定义模型名添加', {
|
||||
name: keyword,
|
||||
});
|
||||
}, [modelSearchMatchedCount, modelSearchValue, t]);
|
||||
const paramOverrideMeta = useMemo(() => {
|
||||
const raw =
|
||||
typeof inputs.param_override === 'string'
|
||||
@@ -338,6 +365,12 @@ const EditChannelModal = (props) => {
|
||||
window.open(targetUrl, '_blank', 'noopener');
|
||||
};
|
||||
const [verifyLoading, setVerifyLoading] = useState(false);
|
||||
const statusCodeRiskConfirmResolverRef = useRef(null);
|
||||
const [statusCodeRiskConfirmVisible, setStatusCodeRiskConfirmVisible] =
|
||||
useState(false);
|
||||
const [statusCodeRiskDetailItems, setStatusCodeRiskDetailItems] = useState(
|
||||
[],
|
||||
);
|
||||
|
||||
// 表单块导航相关状态
|
||||
const formSectionRefs = useRef({
|
||||
@@ -359,6 +392,7 @@ const EditChannelModal = (props) => {
|
||||
const doubaoApiClickCountRef = useRef(0);
|
||||
const initialModelsRef = useRef([]);
|
||||
const initialModelMappingRef = useRef('');
|
||||
const initialStatusCodeMappingRef = useRef('');
|
||||
|
||||
// 2FA状态更新辅助函数
|
||||
const updateTwoFAState = (updates) => {
|
||||
@@ -811,6 +845,10 @@ const EditChannelModal = (props) => {
|
||||
data.disable_store = parsedSettings.disable_store || false;
|
||||
data.allow_safety_identifier =
|
||||
parsedSettings.allow_safety_identifier || false;
|
||||
data.allow_include_obfuscation =
|
||||
parsedSettings.allow_include_obfuscation || false;
|
||||
data.allow_inference_geo =
|
||||
parsedSettings.allow_inference_geo || false;
|
||||
data.claude_beta_query = parsedSettings.claude_beta_query || false;
|
||||
} catch (error) {
|
||||
console.error('解析其他设置失败:', error);
|
||||
@@ -822,6 +860,8 @@ const EditChannelModal = (props) => {
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
data.allow_safety_identifier = false;
|
||||
data.allow_include_obfuscation = false;
|
||||
data.allow_inference_geo = false;
|
||||
data.claude_beta_query = false;
|
||||
}
|
||||
} else {
|
||||
@@ -832,6 +872,8 @@ const EditChannelModal = (props) => {
|
||||
data.allow_service_tier = false;
|
||||
data.disable_store = false;
|
||||
data.allow_safety_identifier = false;
|
||||
data.allow_include_obfuscation = false;
|
||||
data.allow_inference_geo = false;
|
||||
data.claude_beta_query = false;
|
||||
}
|
||||
|
||||
@@ -868,6 +910,7 @@ const EditChannelModal = (props) => {
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
initialModelMappingRef.current = data.model_mapping || '';
|
||||
initialStatusCodeMappingRef.current = data.status_code_mapping || '';
|
||||
|
||||
let parsedIonet = null;
|
||||
if (data.other_info) {
|
||||
@@ -1173,6 +1216,7 @@ const EditChannelModal = (props) => {
|
||||
}, [inputs]);
|
||||
|
||||
useEffect(() => {
|
||||
setModelSearchValue('');
|
||||
if (props.visible) {
|
||||
if (isEdit) {
|
||||
loadChannel();
|
||||
@@ -1194,11 +1238,22 @@ const EditChannelModal = (props) => {
|
||||
if (!isEdit) {
|
||||
initialModelsRef.current = [];
|
||||
initialModelMappingRef.current = '';
|
||||
initialStatusCodeMappingRef.current = '';
|
||||
}
|
||||
}, [isEdit, props.visible]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (statusCodeRiskConfirmResolverRef.current) {
|
||||
statusCodeRiskConfirmResolverRef.current(false);
|
||||
statusCodeRiskConfirmResolverRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// 统一的模态框重置函数
|
||||
const resetModalState = () => {
|
||||
resolveStatusCodeRiskConfirm(false);
|
||||
formApiRef.current?.reset();
|
||||
// 重置渠道设置状态
|
||||
setChannelSettings({
|
||||
@@ -1216,6 +1271,7 @@ const EditChannelModal = (props) => {
|
||||
// 重置豆包隐藏入口状态
|
||||
setDoubaoApiEditUnlocked(false);
|
||||
doubaoApiClickCountRef.current = 0;
|
||||
setModelSearchValue('');
|
||||
// 清空表单中的key_mode字段
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key_mode', undefined);
|
||||
@@ -1328,6 +1384,22 @@ const EditChannelModal = (props) => {
|
||||
});
|
||||
});
|
||||
|
||||
const resolveStatusCodeRiskConfirm = (confirmed) => {
|
||||
setStatusCodeRiskConfirmVisible(false);
|
||||
setStatusCodeRiskDetailItems([]);
|
||||
if (statusCodeRiskConfirmResolverRef.current) {
|
||||
statusCodeRiskConfirmResolverRef.current(confirmed);
|
||||
statusCodeRiskConfirmResolverRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
const confirmStatusCodeRisk = (detailItems) =>
|
||||
new Promise((resolve) => {
|
||||
statusCodeRiskConfirmResolverRef.current = resolve;
|
||||
setStatusCodeRiskDetailItems(detailItems);
|
||||
setStatusCodeRiskConfirmVisible(true);
|
||||
});
|
||||
|
||||
const hasModelConfigChanged = (normalizedModels, modelMappingStr) => {
|
||||
if (!isEdit) return true;
|
||||
const initialModels = initialModelsRef.current;
|
||||
@@ -1518,6 +1590,27 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
}
|
||||
|
||||
const invalidStatusCodeEntries = collectInvalidStatusCodeEntries(
|
||||
localInputs.status_code_mapping,
|
||||
);
|
||||
if (invalidStatusCodeEntries.length > 0) {
|
||||
showError(
|
||||
`${t('状态码复写包含无效的状态码')}: ${invalidStatusCodeEntries.join(', ')}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const riskyStatusCodeRedirects = collectNewDisallowedStatusCodeRedirects(
|
||||
initialStatusCodeMappingRef.current,
|
||||
localInputs.status_code_mapping,
|
||||
);
|
||||
if (riskyStatusCodeRedirects.length > 0) {
|
||||
const confirmed = await confirmStatusCodeRisk(riskyStatusCodeRedirects);
|
||||
if (!confirmed) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||
localInputs.base_url = localInputs.base_url.slice(
|
||||
0,
|
||||
@@ -1570,13 +1663,16 @@ const EditChannelModal = (props) => {
|
||||
// type === 1 (OpenAI) 或 type === 14 (Claude): 设置字段透传控制(显式保存布尔值)
|
||||
if (localInputs.type === 1 || localInputs.type === 14) {
|
||||
settings.allow_service_tier = localInputs.allow_service_tier === true;
|
||||
// 仅 OpenAI 渠道需要 store 和 safety_identifier
|
||||
// 仅 OpenAI 渠道需要 store / safety_identifier / include_obfuscation
|
||||
if (localInputs.type === 1) {
|
||||
settings.disable_store = localInputs.disable_store === true;
|
||||
settings.allow_safety_identifier =
|
||||
localInputs.allow_safety_identifier === true;
|
||||
settings.allow_include_obfuscation =
|
||||
localInputs.allow_include_obfuscation === true;
|
||||
}
|
||||
if (localInputs.type === 14) {
|
||||
settings.allow_inference_geo = localInputs.allow_inference_geo === true;
|
||||
settings.claude_beta_query = localInputs.claude_beta_query === true;
|
||||
}
|
||||
}
|
||||
@@ -1599,6 +1695,8 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.allow_service_tier;
|
||||
delete localInputs.disable_store;
|
||||
delete localInputs.allow_safety_identifier;
|
||||
delete localInputs.allow_include_obfuscation;
|
||||
delete localInputs.allow_inference_geo;
|
||||
delete localInputs.claude_beta_query;
|
||||
|
||||
let res;
|
||||
@@ -2917,9 +3015,18 @@ const EditChannelModal = (props) => {
|
||||
rules={[{ required: true, message: t('请选择模型') }]}
|
||||
multiple
|
||||
filter={selectFilter}
|
||||
allowCreate
|
||||
autoClearSearchValue={false}
|
||||
searchPosition='dropdown'
|
||||
optionList={modelOptions}
|
||||
onSearch={(value) => setModelSearchValue(value)}
|
||||
innerBottomSlot={
|
||||
modelSearchHintText ? (
|
||||
<Text className='px-3 py-2 block text-xs !text-semi-color-text-2'>
|
||||
{modelSearchHintText}
|
||||
</Text>
|
||||
) : null
|
||||
}
|
||||
style={{ width: '100%' }}
|
||||
onChange={(value) => handleInputChange('models', value)}
|
||||
renderSelectedItem={(optionNode) => {
|
||||
@@ -3444,6 +3551,24 @@ const EditChannelModal = (props) => {
|
||||
'safety_identifier 字段用于帮助 OpenAI 识别可能违反使用政策的应用程序用户。默认关闭以保护用户隐私',
|
||||
)}
|
||||
/>
|
||||
|
||||
<Form.Switch
|
||||
field='allow_include_obfuscation'
|
||||
label={t(
|
||||
'允许 stream_options.include_obfuscation 透传',
|
||||
)}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(value) =>
|
||||
handleChannelOtherSettingsChange(
|
||||
'allow_include_obfuscation',
|
||||
value,
|
||||
)
|
||||
}
|
||||
extraText={t(
|
||||
'include_obfuscation 用于控制 Responses 流混淆字段。默认关闭以避免客户端关闭该安全保护',
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -3469,6 +3594,22 @@ const EditChannelModal = (props) => {
|
||||
'service_tier 字段用于指定服务层级,允许透传可能导致实际计费高于预期。默认关闭以避免额外费用',
|
||||
)}
|
||||
/>
|
||||
|
||||
<Form.Switch
|
||||
field='allow_inference_geo'
|
||||
label={t('允许 inference_geo 透传')}
|
||||
checkedText={t('开')}
|
||||
uncheckedText={t('关')}
|
||||
onChange={(value) =>
|
||||
handleChannelOtherSettingsChange(
|
||||
'allow_inference_geo',
|
||||
value,
|
||||
)
|
||||
}
|
||||
extraText={t(
|
||||
'inference_geo 字段用于控制 Claude 数据驻留推理区域。默认关闭以避免未经授权透传地域信息',
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</Card>
|
||||
@@ -3613,6 +3754,12 @@ const EditChannelModal = (props) => {
|
||||
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
|
||||
/>
|
||||
</SideSheet>
|
||||
<StatusCodeRiskGuardModal
|
||||
visible={statusCodeRiskConfirmVisible}
|
||||
detailItems={statusCodeRiskDetailItems}
|
||||
onCancel={() => resolveStatusCodeRiskConfirm(false)}
|
||||
onConfirm={() => resolveStatusCodeRiskConfirm(true)}
|
||||
/>
|
||||
{/* 使用通用安全验证模态框 */}
|
||||
<SecureVerificationModal
|
||||
visible={isModalVisible}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user