mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-04-19 02:07:26 +00:00
Merge branch 'main' of https://github.com/james-6-23/sub2api
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet(
|
||||
|
||||
// ProvideConfig 提供应用配置
|
||||
func ProvideConfig() (*Config, error) {
|
||||
return Load()
|
||||
return LoadForBootstrap()
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ const (
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
PlatformSora = "sora"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
@@ -73,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
|
||||
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
|
||||
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
// Claude 详细版本 ID 映射
|
||||
@@ -87,14 +89,24 @@ var DefaultAntigravityModelMapping = map[string]string{
|
||||
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
|
||||
"gemini-2.5-pro": "gemini-2.5-pro",
|
||||
// Gemini 3 白名单
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||
"gemini-3-pro-image": "gemini-3-pro-image",
|
||||
"gemini-3-flash": "gemini-3-flash",
|
||||
"gemini-3-pro-high": "gemini-3-pro-high",
|
||||
"gemini-3-pro-low": "gemini-3-pro-low",
|
||||
// Gemini 3 preview 映射
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
"gemini-3-flash-preview": "gemini-3-flash",
|
||||
"gemini-3-pro-preview": "gemini-3-pro-high",
|
||||
// Gemini 3.1 白名单
|
||||
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
|
||||
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
|
||||
// Gemini 3.1 preview 映射
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
|
||||
// Gemini 3.1 image 白名单
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
// Gemini 3.1 image preview 映射
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
// Gemini 3 image 兼容映射(向 3.1 image 迁移)
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
|
||||
// 其他官方模型
|
||||
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
|
||||
"tab_flash_lite_preview": "tab_flash_lite_preview",
|
||||
|
||||
24
backend/internal/domain/constants_test.go
Normal file
24
backend/internal/domain/constants_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package domain
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for from, want := range cases {
|
||||
got, ok := DefaultAntigravityModelMapping[from]
|
||||
if !ok {
|
||||
t.Fatalf("expected mapping for %q to exist", from)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
if err := validateDataHeader(dataPayload); err != nil {
|
||||
if err := validateDataHeader(req.Data); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
return h.importData(ctx, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
|
||||
skipDefaultGroupBind := true
|
||||
if req.SkipDefaultGroupBind != nil {
|
||||
skipDefaultGroupBind = *req.SkipDefaultGroupBind
|
||||
}
|
||||
|
||||
dataPayload := req.Data
|
||||
result := DataImportResult{}
|
||||
existingProxies, err := h.listAllProxies(c.Request.Context())
|
||||
|
||||
existingProxies, err := h.listAllProxies(ctx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
return result, err
|
||||
}
|
||||
|
||||
proxyKeyToID := make(map[string]int64, len(existingProxies))
|
||||
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
proxyKeyToID[key] = existingID
|
||||
result.ProxyReused++
|
||||
if normalizedStatus != "" {
|
||||
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
|
||||
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
|
||||
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: defaultProxyName(item.Name),
|
||||
Protocol: item.Protocol,
|
||||
Host: item.Host,
|
||||
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
Username: item.Username,
|
||||
Password: item.Password,
|
||||
})
|
||||
if err != nil {
|
||||
if createErr != nil {
|
||||
result.ProxyFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "proxy",
|
||||
Name: item.Name,
|
||||
ProxyKey: key,
|
||||
Message: err.Error(),
|
||||
Message: createErr.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
result.ProxyCreated++
|
||||
|
||||
if normalizedStatus != "" && normalizedStatus != created.Status {
|
||||
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
|
||||
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
|
||||
Status: normalizedStatus,
|
||||
})
|
||||
}
|
||||
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
SkipDefaultGroupBind: skipDefaultGroupBind,
|
||||
}
|
||||
|
||||
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
|
||||
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
|
||||
result.AccountFailed++
|
||||
result.Errors = append(result.Errors, DataImportError{
|
||||
Kind: "account",
|
||||
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
|
||||
result.AccountCreated++
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
|
||||
@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
router.GET("/api/v1/admin/accounts/data", h.ExportData)
|
||||
|
||||
@@ -2,7 +2,13 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -10,6 +16,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -46,6 +53,7 @@ type AccountHandler struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
sessionLimitCache service.SessionLimitCache
|
||||
rpmCache service.RPMCache
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator
|
||||
}
|
||||
|
||||
@@ -62,6 +70,7 @@ func NewAccountHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
sessionLimitCache service.SessionLimitCache,
|
||||
rpmCache service.RPMCache,
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
@@ -76,6 +85,7 @@ func NewAccountHandler(
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
rpmCache: rpmCache,
|
||||
tokenCacheInvalidator: tokenCacheInvalidator,
|
||||
}
|
||||
}
|
||||
@@ -133,6 +143,13 @@ type BulkUpdateAccountsRequest struct {
|
||||
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
|
||||
}
|
||||
|
||||
// CheckMixedChannelRequest represents check mixed channel risk request
|
||||
type CheckMixedChannelRequest struct {
|
||||
Platform string `json:"platform" binding:"required"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
}
|
||||
|
||||
// AccountWithConcurrency extends Account with real-time concurrency info
|
||||
type AccountWithConcurrency struct {
|
||||
*dto.Account
|
||||
@@ -140,6 +157,51 @@ type AccountWithConcurrency struct {
|
||||
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
|
||||
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
|
||||
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
|
||||
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
|
||||
}
|
||||
|
||||
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
|
||||
item := AccountWithConcurrency{
|
||||
Account: dto.AccountFromService(account),
|
||||
CurrentConcurrency: 0,
|
||||
}
|
||||
if account == nil {
|
||||
return item
|
||||
}
|
||||
|
||||
if h.concurrencyService != nil {
|
||||
if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
|
||||
item.CurrentConcurrency = counts[account.ID]
|
||||
}
|
||||
}
|
||||
|
||||
if account.IsAnthropicOAuthOrSetupToken() {
|
||||
if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
|
||||
cost := stats.StandardCost
|
||||
item.CurrentWindowCost = &cost
|
||||
}
|
||||
}
|
||||
|
||||
if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
|
||||
if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
|
||||
if count, ok := sessions[account.ID]; ok {
|
||||
item.ActiveSessions = &count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.rpmCache != nil && account.GetBaseRPM() > 0 {
|
||||
if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
|
||||
item.CurrentRPM = &rpm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return item
|
||||
}
|
||||
|
||||
// List handles listing all accounts with pagination
|
||||
@@ -155,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
|
||||
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
@@ -173,67 +236,81 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
accountIDs[i] = acc.ID
|
||||
}
|
||||
|
||||
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
// Log error but don't fail the request, just use 0 for all
|
||||
concurrencyCounts = make(map[int64]int)
|
||||
}
|
||||
|
||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 并行获取窗口费用和活跃会话数
|
||||
concurrencyCounts := make(map[int64]int)
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
var rpmCounts map[int64]int
|
||||
if !lite {
|
||||
// Get current concurrency counts for all accounts
|
||||
if h.concurrencyService != nil {
|
||||
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
|
||||
concurrencyCounts = cc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
// 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
rpmAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
if acc.GetWindowCostLimit() > 0 {
|
||||
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
if acc.GetBaseRPM() > 0 {
|
||||
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 RPM 计数(批量查询)
|
||||
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
|
||||
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
|
||||
if rpmCounts == nil {
|
||||
rpmCounts = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取窗口费用(并行查询)
|
||||
if len(windowCostAccountIDs) > 0 {
|
||||
windowCosts = make(map[int64]float64)
|
||||
var mu sync.Mutex
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.SetLimit(10) // 限制并发数
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
|
||||
continue
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
|
||||
mu.Unlock()
|
||||
}
|
||||
return nil // 不返回错误,允许部分失败
|
||||
})
|
||||
}
|
||||
_ = g.Wait()
|
||||
}
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// Build response with concurrency info
|
||||
@@ -259,12 +336,84 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 添加 RPM 计数(仅当启用时)
|
||||
if rpmCounts != nil {
|
||||
if rpm, ok := rpmCounts[acc.ID]; ok {
|
||||
item.CurrentRPM = &rpm
|
||||
}
|
||||
}
|
||||
|
||||
result[i] = item
|
||||
}
|
||||
|
||||
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
|
||||
if etag != "" {
|
||||
c.Header("ETag", etag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
func buildAccountsListETag(
|
||||
items []AccountWithConcurrency,
|
||||
total int64,
|
||||
page, pageSize int,
|
||||
platform, accountType, status, search string,
|
||||
lite bool,
|
||||
) string {
|
||||
payload := struct {
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Platform string `json:"platform"`
|
||||
AccountType string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Search string `json:"search"`
|
||||
Lite bool `json:"lite"`
|
||||
Items []AccountWithConcurrency `json:"items"`
|
||||
}{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Platform: platform,
|
||||
AccountType: accountType,
|
||||
Status: status,
|
||||
Search: search,
|
||||
Lite: lite,
|
||||
Items: items,
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
return "\"" + hex.EncodeToString(sum[:]) + "\""
|
||||
}
|
||||
|
||||
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
|
||||
if etag == "" || ifNoneMatch == "" {
|
||||
return false
|
||||
}
|
||||
for _, token := range strings.Split(ifNoneMatch, ",") {
|
||||
candidate := strings.TrimSpace(token)
|
||||
if candidate == "*" {
|
||||
return true
|
||||
}
|
||||
if candidate == etag {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetByID handles getting an account by ID
|
||||
// GET /api/v1/admin/accounts/:id
|
||||
func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
@@ -280,7 +429,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
|
||||
// POST /api/v1/admin/accounts/check-mixed-channel
|
||||
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
|
||||
var req CheckMixedChannelRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.GroupIDs) == 0 {
|
||||
response.Success(c, gin.H{"has_risk": false})
|
||||
return
|
||||
}
|
||||
|
||||
accountID := int64(0)
|
||||
if req.AccountID != nil {
|
||||
accountID = *req.AccountID
|
||||
}
|
||||
|
||||
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
|
||||
if err != nil {
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
response.Success(c, gin.H{
|
||||
"has_risk": true,
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"has_risk": false})
|
||||
}
|
||||
|
||||
// Create handles creating a new account
|
||||
@@ -295,50 +488,57 @@ func (h *AccountHandler) Create(c *gin.Context) {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(req.Extra)
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Notes: req.Notes,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: req.Name,
|
||||
Notes: req.Notes,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
GroupIDs: req.GroupIDs,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
AutoPauseOnExpired: req.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return h.buildAccountResponseWithRuntime(ctx, account), nil
|
||||
})
|
||||
if err != nil {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
|
||||
// Update handles updating an account
|
||||
@@ -359,6 +559,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(req.Extra)
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
@@ -383,17 +585,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
// 检查是否为混合渠道错误
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
// 返回特殊错误码要求确认
|
||||
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
"details": gin.H{
|
||||
"group_id": mixedErr.GroupID,
|
||||
"group_name": mixedErr.GroupName,
|
||||
"current_platform": mixedErr.CurrentPlatform,
|
||||
"other_platform": mixedErr.OtherPlatform,
|
||||
},
|
||||
"require_confirmation": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -402,7 +597,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// Delete handles deleting an account
|
||||
@@ -660,7 +855,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
|
||||
}
|
||||
|
||||
// GetStats handles getting account statistics
|
||||
@@ -718,7 +913,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// BatchCreate handles batch creating accounts
|
||||
@@ -732,61 +927,65 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
success := 0
|
||||
failed := 0
|
||||
results := make([]gin.H, 0, len(req.Accounts))
|
||||
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
for _, item := range req.Accounts {
|
||||
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(item.Extra)
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": "rate_multiplier must be >= 0",
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
|
||||
|
||||
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
|
||||
Name: item.Name,
|
||||
Notes: item.Notes,
|
||||
Platform: item.Platform,
|
||||
Type: item.Type,
|
||||
Credentials: item.Credentials,
|
||||
Extra: item.Extra,
|
||||
ProxyID: item.ProxyID,
|
||||
Concurrency: item.Concurrency,
|
||||
Priority: item.Priority,
|
||||
RateMultiplier: item.RateMultiplier,
|
||||
GroupIDs: item.GroupIDs,
|
||||
ExpiresAt: item.ExpiresAt,
|
||||
AutoPauseOnExpired: item.AutoPauseOnExpired,
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"name": item.Name,
|
||||
"id": account.ID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
return gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -824,57 +1023,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
// 阶段一:预验证所有账号存在,收集 credentials
|
||||
type accountUpdate struct {
|
||||
ID int64
|
||||
Credentials map[string]any
|
||||
}
|
||||
updates := make([]accountUpdate, 0, len(req.AccountIDs))
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
|
||||
return
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
|
||||
}
|
||||
|
||||
// Update account
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
|
||||
success := 0
|
||||
failed := 0
|
||||
successIDs := make([]int64, 0, len(updates))
|
||||
failedIDs := make([]int64, 0, len(updates))
|
||||
results := make([]gin.H, 0, len(updates))
|
||||
for _, u := range updates {
|
||||
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
|
||||
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
|
||||
failed++
|
||||
failedIDs = append(failedIDs, u.ID)
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"account_id": u.ID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
success++
|
||||
successIDs = append(successIDs, u.ID)
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"account_id": u.ID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"success_ids": successIDs,
|
||||
"failed_ids": failedIDs,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -890,6 +1090,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
response.BadRequest(c, "rate_multiplier must be >= 0")
|
||||
return
|
||||
}
|
||||
// base_rpm 输入校验:负值归零,超过 10000 截断
|
||||
sanitizeExtraBaseRPM(req.Extra)
|
||||
|
||||
// 确定是否跳过混合渠道检查
|
||||
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
|
||||
@@ -925,6 +1127,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
||||
SkipMixedChannelCheck: skipCheck,
|
||||
})
|
||||
if err != nil {
|
||||
var mixedErr *service.MixedChannelError
|
||||
if errors.As(err, &mixedErr) {
|
||||
c.JSON(409, gin.H{
|
||||
"error": "mixed_channel_warning",
|
||||
"message": mixedErr.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@@ -1109,7 +1319,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetTempUnschedulable handles getting temporary unschedulable status
|
||||
@@ -1173,6 +1389,57 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
|
||||
response.Success(c, stats)
|
||||
}
|
||||
|
||||
// BatchTodayStatsRequest 批量今日统计请求体。
|
||||
type BatchTodayStatsRequest struct {
|
||||
AccountIDs []int64 `json:"account_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchTodayStats 批量获取多个账号的今日统计。
|
||||
// POST /api/v1/admin/accounts/today-stats/batch
|
||||
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
|
||||
var req BatchTodayStatsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
accountIDs := normalizeInt64IDList(req.AccountIDs)
|
||||
if len(accountIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
|
||||
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
payload := gin.H{"stats": stats}
|
||||
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// SetSchedulableRequest represents the request body for setting schedulable status
|
||||
type SetSchedulableRequest struct {
|
||||
Schedulable bool `json:"schedulable"`
|
||||
@@ -1199,7 +1466,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
|
||||
}
|
||||
|
||||
// GetAvailableModels handles getting available models for an account
|
||||
@@ -1296,32 +1563,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
|
||||
// Handle Antigravity accounts: return Claude + Gemini models
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
// Antigravity 支持 Claude 和部分 Gemini 模型
|
||||
type UnifiedModel struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
// 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
|
||||
response.Success(c, antigravity.DefaultModels())
|
||||
return
|
||||
}
|
||||
|
||||
var models []UnifiedModel
|
||||
|
||||
// 添加 Claude 模型
|
||||
for _, m := range claude.DefaultModels {
|
||||
models = append(models, UnifiedModel{
|
||||
ID: m.ID,
|
||||
Type: m.Type,
|
||||
DisplayName: m.DisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
// 添加 Gemini 3 系列模型用于测试
|
||||
geminiTestModels := []UnifiedModel{
|
||||
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"},
|
||||
}
|
||||
models = append(models, geminiTestModels...)
|
||||
|
||||
response.Success(c, models)
|
||||
// Handle Sora accounts
|
||||
if account.Platform == service.PlatformSora {
|
||||
response.Success(c, service.DefaultSoraModels(nil))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1532,3 +1781,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
|
||||
response.Success(c, domain.DefaultAntigravityModelMapping)
|
||||
}
|
||||
|
||||
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
|
||||
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
|
||||
func sanitizeExtraBaseRPM(extra map[string]any) {
|
||||
if extra == nil {
|
||||
return
|
||||
}
|
||||
raw, ok := extra["base_rpm"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
v := service.ParseExtraInt(raw)
|
||||
if v < 0 {
|
||||
v = 0
|
||||
} else if v > 10000 {
|
||||
v = 10000
|
||||
}
|
||||
extra["base_rpm"] = v
|
||||
}
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
|
||||
router.POST("/api/v1/admin/accounts", accountHandler.Create)
|
||||
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
|
||||
router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"platform": "antigravity",
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, false, data["has_risk"])
|
||||
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
|
||||
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
|
||||
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
|
||||
}
|
||||
|
||||
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.checkMixedErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"platform": "antigravity",
|
||||
"group_ids": []int64{27},
|
||||
"account_id": 99,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["has_risk"])
|
||||
require.Equal(t, "mixed_channel_warning", data["error"])
|
||||
details, ok := data["details"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(27), details["group_id"])
|
||||
require.Equal(t, "claude-max", details["group_name"])
|
||||
require.Equal(t, "Antigravity", details["current_platform"])
|
||||
require.Equal(t, "Anthropic", details["other_platform"])
|
||||
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.createAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"name": "ag-oauth-1",
|
||||
"platform": "antigravity",
|
||||
"type": "oauth",
|
||||
"credentials": map[string]any{"refresh_token": "rt"},
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
require.False(t, hasRequireConfirmation)
|
||||
}
|
||||
|
||||
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.updateAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "mixed_channel_warning")
|
||||
_, hasDetails := resp["details"]
|
||||
_, hasRequireConfirmation := resp["require_confirmation"]
|
||||
require.False(t, hasDetails)
|
||||
require.False(t, hasRequireConfirmation)
|
||||
}
|
||||
|
||||
func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
|
||||
GroupID: 27,
|
||||
GroupName: "claude-max",
|
||||
CurrentPlatform: "Antigravity",
|
||||
OtherPlatform: "Anthropic",
|
||||
}
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1, 2, 3},
|
||||
"group_ids": []int64{27},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "mixed_channel_warning", resp["error"])
|
||||
require.Contains(t, resp["message"], "claude-max")
|
||||
}
|
||||
|
||||
func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
|
||||
adminSvc := newStubAdminService()
|
||||
router := setupAccountMixedChannelRouter(adminSvc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1, 2},
|
||||
"group_ids": []int64{27},
|
||||
"confirm_mixed_channel_risk": true,
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, float64(0), resp["code"])
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, float64(2), data["success"])
|
||||
require.Equal(t, float64(0), data["failed"])
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
adminSvc := newStubAdminService()
|
||||
handler := NewAccountHandler(
|
||||
adminSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
router := gin.New()
|
||||
router.POST("/api/v1/admin/accounts", handler.Create)
|
||||
|
||||
body := map[string]any{
|
||||
"name": "anthropic-key-1",
|
||||
"platform": "anthropic",
|
||||
"type": "apikey",
|
||||
"credentials": map[string]any{
|
||||
"api_key": "sk-ant-xxx",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
},
|
||||
"extra": map[string]any{
|
||||
"anthropic_passthrough": true,
|
||||
},
|
||||
"concurrency": 1,
|
||||
"priority": 1,
|
||||
}
|
||||
raw, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Len(t, adminSvc.createdAccounts, 1)
|
||||
|
||||
created := adminSvc.createdAccounts[0]
|
||||
require.Equal(t, "anthropic", created.Platform)
|
||||
require.Equal(t, "apikey", created.Type)
|
||||
require.NotNil(t, created.Extra)
|
||||
require.Equal(t, true, created.Extra["anthropic_passthrough"])
|
||||
}
|
||||
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
25
backend/internal/handler/admin/account_today_stats_cache.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
|
||||
if len(accountIDs) == 0 {
|
||||
return "accounts_today_stats_empty"
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(accountIDs) * 6)
|
||||
_, _ = b.WriteString("accounts_today_stats:")
|
||||
for i, id := range accountIDs {
|
||||
if i > 0 {
|
||||
_ = b.WriteByte(',')
|
||||
}
|
||||
_, _ = b.WriteString(strconv.FormatInt(id, 10))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
userHandler := NewUserHandler(adminSvc, nil)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc, nil)
|
||||
|
||||
router.GET("/api/v1/admin/users", userHandler.List)
|
||||
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
|
||||
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want time.Duration
|
||||
ok bool
|
||||
}{
|
||||
{input: "30m", want: 30 * time.Minute, ok: true},
|
||||
{input: "1h", want: time.Hour, ok: true},
|
||||
{input: "1d", want: 24 * time.Hour, ok: true},
|
||||
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
|
||||
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
|
||||
{input: "7d", want: 0, ok: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
|
||||
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
|
||||
require.Equal(t, tt.want, got, "input=%s", tt.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
before := time.Now().UTC()
|
||||
filter, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
after := time.Now().UTC()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, filter)
|
||||
require.Equal(t, "30d", filter.TimeRange)
|
||||
require.Equal(t, 1, filter.Page)
|
||||
require.Equal(t, 20, filter.PageSize)
|
||||
require.Equal(t, 0, filter.TopN)
|
||||
require.Nil(t, filter.GroupID)
|
||||
require.Equal(t, "", filter.Platform)
|
||||
require.True(t, filter.StartTime.Before(filter.EndTime))
|
||||
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
|
||||
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
|
||||
nil,
|
||||
)
|
||||
|
||||
filter, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1h", filter.TimeRange)
|
||||
require.Equal(t, "openai", filter.Platform)
|
||||
require.NotNil(t, filter.GroupID)
|
||||
require.Equal(t, int64(12), *filter.GroupID)
|
||||
require.Equal(t, 50, filter.TopN)
|
||||
require.Equal(t, 0, filter.Page)
|
||||
require.Equal(t, 0, filter.PageSize)
|
||||
}
|
||||
|
||||
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
|
||||
tests := []string{
|
||||
"/?time_range=7d",
|
||||
"/?group_id=0",
|
||||
"/?group_id=abc",
|
||||
"/?top_n=0",
|
||||
"/?top_n=101",
|
||||
"/?top_n=10&page=1",
|
||||
"/?top_n=10&page_size=20",
|
||||
"/?page=0",
|
||||
"/?page_size=0",
|
||||
"/?page_size=101",
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
for _, rawURL := range tests {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
|
||||
_, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
require.Error(t, err, "url=%s", rawURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpsTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -10,19 +10,28 @@ import (
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
mu sync.Mutex
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
createdAccounts []*service.CreateAccountInput
|
||||
createdProxies []*service.CreateProxyInput
|
||||
updatedProxyIDs []int64
|
||||
updatedProxies []*service.UpdateProxyInput
|
||||
testedProxyIDs []int64
|
||||
createAccountErr error
|
||||
updateAccountErr error
|
||||
bulkUpdateAccountErr error
|
||||
checkMixedErr error
|
||||
lastMixedCheck struct {
|
||||
accountID int64
|
||||
platform string
|
||||
groupIDs []int64
|
||||
}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
@@ -188,11 +197,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
|
||||
s.mu.Lock()
|
||||
s.createdAccounts = append(s.createdAccounts, input)
|
||||
s.mu.Unlock()
|
||||
if s.createAccountErr != nil {
|
||||
return nil, s.createAccountErr
|
||||
}
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
if s.updateAccountErr != nil {
|
||||
return nil, s.updateAccountErr
|
||||
}
|
||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
@@ -221,7 +236,17 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
|
||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||
if s.bulkUpdateAccountErr != nil {
|
||||
return nil, s.bulkUpdateAccountErr
|
||||
}
|
||||
return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
s.lastMixedCheck.accountID = currentAccountID
|
||||
s.lastMixedCheck.platform = currentAccountPlatform
|
||||
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
|
||||
return s.checkMixedErr
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
@@ -327,6 +352,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
|
||||
return &service.ProxyQualityCheckResult{
|
||||
ProxyID: id,
|
||||
Score: 95,
|
||||
Grade: "A",
|
||||
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
|
||||
PassedCount: 5,
|
||||
WarnCount: 0,
|
||||
FailedCount: 0,
|
||||
ChallengeCount: 0,
|
||||
CheckedAt: time.Now().Unix(),
|
||||
Items: []service.ProxyQualityCheckItem{
|
||||
{Target: "base_connectivity", Status: "pass", Message: "ok"},
|
||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
@@ -361,5 +407,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
|
||||
for i := range s.apiKeys {
|
||||
if s.apiKeys[i].ID == keyID {
|
||||
k := s.apiKeys[i]
|
||||
if groupID != nil {
|
||||
if *groupID == 0 {
|
||||
k.GroupID = nil
|
||||
} else {
|
||||
gid := *groupID
|
||||
k.GroupID = &gid
|
||||
}
|
||||
}
|
||||
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
|
||||
}
|
||||
}
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
|
||||
63
backend/internal/handler/admin/apikey_handler.go
Normal file
63
backend/internal/handler/admin/apikey_handler.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminAPIKeyHandler handles admin API key management
|
||||
type AdminAPIKeyHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewAdminAPIKeyHandler creates a new admin API key handler
|
||||
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
|
||||
return &AdminAPIKeyHandler{
|
||||
adminService: adminService,
|
||||
}
|
||||
}
|
||||
|
||||
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
|
||||
type AdminUpdateAPIKeyGroupRequest struct {
|
||||
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
|
||||
}
|
||||
|
||||
// UpdateGroup handles updating an API key's group binding
|
||||
// PUT /api/v1/admin/api-keys/:id
|
||||
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
|
||||
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid API key ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req AdminUpdateAPIKeyGroupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
APIKey *dto.APIKey `json:"api_key"`
|
||||
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
|
||||
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
|
||||
GrantedGroupName string `json:"granted_group_name,omitempty"`
|
||||
}{
|
||||
APIKey: dto.APIKeyFromService(result.APIKey),
|
||||
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
|
||||
GrantedGroupID: result.GrantedGroupID,
|
||||
GrantedGroupName: result.GrantedGroupName,
|
||||
}
|
||||
response.Success(c, resp)
|
||||
}
|
||||
202
backend/internal/handler/admin/apikey_handler_test.go
Normal file
202
backend/internal/handler/admin/apikey_handler_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
h := NewAdminAPIKeyHandler(adminSvc)
|
||||
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
|
||||
router := setupAPIKeyHandler(newStubAdminService())
|
||||
body := `{"group_id": 2}`
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Invalid API key ID")
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
|
||||
router := setupAPIKeyHandler(newStubAdminService())
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Invalid request")
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
|
||||
router := setupAPIKeyHandler(newStubAdminService())
|
||||
body := `{"group_id": 2}`
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// ErrAPIKeyNotFound maps to 404
|
||||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
|
||||
router := setupAPIKeyHandler(newStubAdminService())
|
||||
body := `{"group_id": 2}`
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
|
||||
var data struct {
|
||||
APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
} `json:"api_key"`
|
||||
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(resp.Data, &data))
|
||||
require.Equal(t, int64(10), data.APIKey.ID)
|
||||
require.NotNil(t, data.APIKey.GroupID)
|
||||
require.Equal(t, int64(2), *data.APIKey.GroupID)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
|
||||
svc := newStubAdminService()
|
||||
gid := int64(2)
|
||||
svc.apiKeys[0].GroupID = &gid
|
||||
router := setupAPIKeyHandler(svc)
|
||||
body := `{"group_id": 0}`
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
APIKey struct {
|
||||
GroupID *int64 `json:"group_id"`
|
||||
} `json:"api_key"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Nil(t, resp.Data.APIKey.GroupID)
|
||||
}
|
||||
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
|
||||
svc := &failingUpdateGroupService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
err: errors.New("internal failure"),
|
||||
}
|
||||
router := setupAPIKeyHandler(svc)
|
||||
body := `{"group_id": 2}`
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
// H2: empty body → group_id is nil → no-op, returns original key
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
|
||||
router := setupAPIKeyHandler(newStubAdminService())
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
} `json:"api_key"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, int64(10), resp.Data.APIKey.ID)
|
||||
}
|
||||
|
||||
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
|
||||
svc := &failingUpdateGroupService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
|
||||
}
|
||||
router := setupAPIKeyHandler(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
|
||||
}
|
||||
|
||||
// M2: service returns INVALID_GROUP_ID → handler maps to 400
|
||||
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
|
||||
svc := &failingUpdateGroupService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
|
||||
}
|
||||
router := setupAPIKeyHandler(svc)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
|
||||
}
|
||||
|
||||
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
|
||||
type failingUpdateGroupService struct {
|
||||
*stubAdminService
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
|
||||
return nil, f.err
|
||||
}
|
||||
208
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
208
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
|
||||
type failingAdminService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
updateCallCount atomic.Int64
|
||||
}
|
||||
|
||||
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
f.updateCallCount.Add(1)
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("database error")
|
||||
}
|
||||
return f.stubAdminService.UpdateAccount(ctx, id, input)
|
||||
}
|
||||
|
||||
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
|
||||
return router, handler
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test-uuid",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
|
||||
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
|
||||
// 让第 2 个账号(ID=2)更新时失败
|
||||
svc := &failingAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 2,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "org_uuid",
|
||||
Value: "test-org",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
|
||||
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
data := resp["data"].(map[string]any)
|
||||
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
|
||||
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
|
||||
|
||||
// 所有 3 个账号都会被尝试更新(非 fail-fast)
|
||||
require.Equal(t, int64(3), svc.updateCallCount.Load(),
|
||||
"应调用 3 次 UpdateAccount(逐个尝试,失败后继续)")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
|
||||
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
|
||||
svc := &getAccountFailingService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 1,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
|
||||
}
|
||||
|
||||
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
|
||||
type getAccountFailingService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
}
|
||||
|
||||
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return f.stubAdminService.GetAccount(ctx, id)
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": "not-a-bool",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"intercept_warmup_requests 传入非 bool 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": true,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"intercept_warmup_requests 传入合法 bool 值应返回 200")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入非 string 类型(number),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": 12345,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"account_uuid 传入非 string 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入 null(设置为空),应正常通过
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": nil,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"account_uuid 传入 null 应返回 200")
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -186,7 +188,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
@@ -194,6 +196,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var model string
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
@@ -220,9 +223,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
if modelStr := c.Query("model"); modelStr != "" {
|
||||
model = modelStr
|
||||
}
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
@@ -235,7 +249,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -251,12 +265,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
@@ -280,9 +295,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
@@ -295,7 +321,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
@@ -308,6 +334,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetGroupStats handles getting group usage statistics
|
||||
// GET /api/v1/admin/dashboard/groups
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
|
||||
func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
|
||||
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
|
||||
apiKeyID = id
|
||||
}
|
||||
}
|
||||
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
|
||||
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
|
||||
accountID = id
|
||||
}
|
||||
}
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
|
||||
groupID = id
|
||||
}
|
||||
}
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
|
||||
stream = &streamVal
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
return
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"groups": stats,
|
||||
"start_date": startTime.Format("2006-01-02"),
|
||||
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
})
|
||||
}
|
||||
|
||||
// GetAPIKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
@@ -365,6 +461,9 @@ type BatchUsersUsageRequest struct {
|
||||
UserIDs []int64 `json:"user_ids" binding:"required"`
|
||||
}
|
||||
|
||||
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
|
||||
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// GetBatchUsersUsage handles getting usage stats for multiple users
|
||||
// POST /api/v1/admin/dashboard/users-usage
|
||||
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
@@ -374,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||
@@ -402,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids"`
|
||||
}{
|
||||
APIKeyIDs: apiKeyIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
payload := gin.H{"stats": stats}
|
||||
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
trendRequestType *int16
|
||||
trendStream *bool
|
||||
modelRequestType *int16
|
||||
modelStream *bool
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
granularity string,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
model string,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.TrendDataPoint, error) {
|
||||
s.trendRequestType = requestType
|
||||
s.trendStream = stream
|
||||
return []usagestats.TrendDataPoint{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
|
||||
ctx context.Context,
|
||||
startTime, endTime time.Time,
|
||||
userID, apiKeyID, accountID, groupID int64,
|
||||
requestType *int16,
|
||||
stream *bool,
|
||||
billingType *int8,
|
||||
) ([]usagestats.ModelStat, error) {
|
||||
s.modelRequestType = requestType
|
||||
s.modelStream = stream
|
||||
return []usagestats.ModelStat{}, nil
|
||||
}
|
||||
|
||||
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
|
||||
handler := NewDashboardHandler(dashboardSvc, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
|
||||
router.GET("/admin/dashboard/models", handler.GetModelStats)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestDashboardTrendRequestTypePriority(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.trendRequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
|
||||
require.Nil(t, repo.trendStream)
|
||||
}
|
||||
|
||||
func TestDashboardTrendInvalidRequestType(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardTrendInvalidStream(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.modelRequestType)
|
||||
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
|
||||
require.Nil(t, repo.modelStream)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestDashboardModelStatsInvalidStream(t *testing.T) {
|
||||
repo := &dashboardUsageRepoCapture{}
|
||||
router := newDashboardRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
292
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
292
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
type dashboardSnapshotV2Stats struct {
|
||||
usagestats.DashboardStats
|
||||
Uptime int64 `json:"uptime"`
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2Response struct {
|
||||
GeneratedAt string `json:"generated_at"`
|
||||
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
Granularity string `json:"granularity"`
|
||||
|
||||
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
|
||||
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
|
||||
Models []usagestats.ModelStat `json:"models,omitempty"`
|
||||
Groups []usagestats.GroupStat `json:"groups,omitempty"`
|
||||
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2Filters struct {
|
||||
UserID int64
|
||||
APIKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
}
|
||||
|
||||
type dashboardSnapshotV2CacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Granularity string `json:"granularity"`
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
GroupID int64 `json:"group_id"`
|
||||
Model string `json:"model"`
|
||||
RequestType *int16 `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
IncludeStats bool `json:"include_stats"`
|
||||
IncludeTrend bool `json:"include_trend"`
|
||||
IncludeModels bool `json:"include_models"`
|
||||
IncludeGroups bool `json:"include_groups"`
|
||||
IncludeUsersTrend bool `json:"include_users_trend"`
|
||||
UsersTrendLimit int `json:"users_trend_limit"`
|
||||
}
|
||||
|
||||
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
|
||||
if granularity != "hour" {
|
||||
granularity = "day"
|
||||
}
|
||||
|
||||
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
|
||||
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
|
||||
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
|
||||
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
|
||||
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
|
||||
usersTrendLimit := 12
|
||||
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
|
||||
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
|
||||
usersTrendLimit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
filters, err := parseDashboardSnapshotV2Filters(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Granularity: granularity,
|
||||
UserID: filters.UserID,
|
||||
APIKeyID: filters.APIKeyID,
|
||||
AccountID: filters.AccountID,
|
||||
GroupID: filters.GroupID,
|
||||
Model: filters.Model,
|
||||
RequestType: filters.RequestType,
|
||||
Stream: filters.Stream,
|
||||
BillingType: filters.BillingType,
|
||||
IncludeStats: includeStats,
|
||||
IncludeTrend: includeTrend,
|
||||
IncludeModels: includeModels,
|
||||
IncludeGroups: includeGroups,
|
||||
IncludeUsersTrend: includeUsersTrend,
|
||||
UsersTrendLimit: usersTrendLimit,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &dashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
StartDate: startTime.Format("2006-01-02"),
|
||||
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
|
||||
Granularity: granularity,
|
||||
}
|
||||
|
||||
if includeStats {
|
||||
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get dashboard statistics")
|
||||
return
|
||||
}
|
||||
resp.Stats = &dashboardSnapshotV2Stats{
|
||||
DashboardStats: *stats,
|
||||
Uptime: int64(time.Since(h.startTime).Seconds()),
|
||||
}
|
||||
}
|
||||
|
||||
if includeTrend {
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.Model,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
}
|
||||
resp.Trend = trend
|
||||
}
|
||||
|
||||
if includeModels {
|
||||
models, err := h.dashboardService.GetModelStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
}
|
||||
resp.Models = models
|
||||
}
|
||||
|
||||
if includeGroups {
|
||||
groups, err := h.dashboardService.GetGroupStatsWithFilters(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
filters.UserID,
|
||||
filters.APIKeyID,
|
||||
filters.AccountID,
|
||||
filters.GroupID,
|
||||
filters.RequestType,
|
||||
filters.Stream,
|
||||
filters.BillingType,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get group statistics")
|
||||
return
|
||||
}
|
||||
resp.Groups = groups
|
||||
}
|
||||
|
||||
if includeUsersTrend {
|
||||
usersTrend, err := h.dashboardService.GetUserUsageTrend(
|
||||
c.Request.Context(),
|
||||
startTime,
|
||||
endTime,
|
||||
granularity,
|
||||
usersTrendLimit,
|
||||
)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage trend")
|
||||
return
|
||||
}
|
||||
resp.UsersTrend = usersTrend
|
||||
}
|
||||
|
||||
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
|
||||
filters := &dashboardSnapshotV2Filters{
|
||||
Model: strings.TrimSpace(c.Query("model")),
|
||||
}
|
||||
|
||||
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
|
||||
id, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.UserID = id
|
||||
}
|
||||
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
|
||||
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.APIKeyID = id
|
||||
}
|
||||
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
|
||||
id, err := strconv.ParseInt(accountIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.AccountID = id
|
||||
}
|
||||
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
|
||||
id, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.GroupID = id
|
||||
}
|
||||
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value := int16(parsed)
|
||||
filters.RequestType = &value
|
||||
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
|
||||
streamVal, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filters.Stream = &streamVal
|
||||
}
|
||||
|
||||
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
|
||||
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bt := int8(v)
|
||||
filters.BillingType = &bt
|
||||
}
|
||||
|
||||
return filters, nil
|
||||
}
|
||||
545
backend/internal/handler/admin/data_management_handler.go
Normal file
545
backend/internal/handler/admin/data_management_handler.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DataManagementHandler struct {
|
||||
dataManagementService dataManagementService
|
||||
}
|
||||
|
||||
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
|
||||
return &DataManagementHandler{dataManagementService: dataManagementService}
|
||||
}
|
||||
|
||||
type dataManagementService interface {
|
||||
GetConfig(ctx context.Context) (service.DataManagementConfig, error)
|
||||
UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
|
||||
ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
|
||||
CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
|
||||
ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
|
||||
CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
|
||||
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
|
||||
DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
|
||||
SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
|
||||
ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
|
||||
CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
|
||||
UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
|
||||
DeleteS3Profile(ctx context.Context, profileID string) error
|
||||
SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
|
||||
ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
|
||||
GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
|
||||
EnsureAgentEnabled(ctx context.Context) error
|
||||
GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
|
||||
}
|
||||
|
||||
type TestS3ConnectionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region" binding:"required"`
|
||||
Bucket string `json:"bucket" binding:"required"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
type CreateBackupJobRequest struct {
|
||||
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
|
||||
UploadToS3 bool `json:"upload_to_s3"`
|
||||
S3ProfileID string `json:"s3_profile_id"`
|
||||
PostgresID string `json:"postgres_profile_id"`
|
||||
RedisID string `json:"redis_profile_id"`
|
||||
IdempotencyKey string `json:"idempotency_key"`
|
||||
}
|
||||
|
||||
type CreateSourceProfileRequest struct {
|
||||
ProfileID string `json:"profile_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||
SetActive bool `json:"set_active"`
|
||||
}
|
||||
|
||||
type UpdateSourceProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
|
||||
}
|
||||
|
||||
type CreateS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
SetActive bool `json:"set_active"`
|
||||
}
|
||||
|
||||
type UpdateS3ProfileRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
|
||||
health := h.getAgentHealth(c)
|
||||
payload := gin.H{
|
||||
"enabled": health.Enabled,
|
||||
"reason": health.Reason,
|
||||
"socket_path": health.SocketPath,
|
||||
}
|
||||
if health.Agent != nil {
|
||||
payload["agent"] = gin.H{
|
||||
"status": health.Agent.Status,
|
||||
"version": health.Agent.Version,
|
||||
"uptime_seconds": health.Agent.UptimeSeconds,
|
||||
}
|
||||
}
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
|
||||
var req service.DataManagementConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) TestS3(c *gin.Context) {
|
||||
var req TestS3ConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
|
||||
var req CreateBackupJobRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
triggeredBy := "admin:unknown"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
|
||||
BackupType: req.BackupType,
|
||||
UploadToS3: req.UploadToS3,
|
||||
S3ProfileID: req.S3ProfileID,
|
||||
PostgresID: req.PostgresID,
|
||||
RedisID: req.RedisID,
|
||||
TriggeredBy: triggeredBy,
|
||||
IdempotencyKey: req.IdempotencyKey,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType == "" {
|
||||
response.BadRequest(c, "Invalid source_type")
|
||||
return
|
||||
}
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateSourceProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
|
||||
SourceType: sourceType,
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Config: req.Config,
|
||||
SetActive: req.SetActive,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSourceProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
|
||||
SourceType: sourceType,
|
||||
ProfileID: profileID,
|
||||
Name: req.Name,
|
||||
Config: req.Config,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
|
||||
sourceType := strings.TrimSpace(c.Param("source_type"))
|
||||
if sourceType != "postgres" && sourceType != "redis" {
|
||||
response.BadRequest(c, "source_type must be postgres or redis")
|
||||
return
|
||||
}
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"items": items})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
|
||||
var req CreateS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
SetActive: req.SetActive,
|
||||
S3: service.DataManagementS3Config{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
|
||||
var req UpdateS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
|
||||
ProfileID: profileID,
|
||||
Name: req.Name,
|
||||
S3: service.DataManagementS3Config{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
UseSSL: req.UseSSL,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Invalid profile_id")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, profile)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
|
||||
pageSize := int32(20)
|
||||
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v <= 0 {
|
||||
response.BadRequest(c, "Invalid page_size")
|
||||
return
|
||||
}
|
||||
pageSize = int32(v)
|
||||
}
|
||||
|
||||
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
|
||||
PageSize: pageSize,
|
||||
PageToken: c.Query("page_token"),
|
||||
Status: c.Query("status"),
|
||||
BackupType: c.Query("backup_type"),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
|
||||
jobID := strings.TrimSpace(c.Param("job_id"))
|
||||
if jobID == "" {
|
||||
response.BadRequest(c, "Invalid backup job ID")
|
||||
return
|
||||
}
|
||||
|
||||
if !h.requireAgentEnabled(c) {
|
||||
return
|
||||
}
|
||||
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, job)
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
|
||||
if h.dataManagementService == nil {
|
||||
err := infraerrors.ServiceUnavailable(
|
||||
service.DataManagementAgentUnavailableReason,
|
||||
"data management agent service is not configured",
|
||||
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
|
||||
response.ErrorFrom(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
|
||||
if h.dataManagementService == nil {
|
||||
return service.DataManagementAgentHealth{
|
||||
Enabled: false,
|
||||
Reason: service.DataManagementAgentUnavailableReason,
|
||||
SocketPath: service.DefaultDataManagementAgentSocketPath,
|
||||
}
|
||||
}
|
||||
return h.dataManagementService.GetAgentHealth(c.Request.Context())
|
||||
}
|
||||
|
||||
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
|
||||
headerKey := strings.TrimSpace(headerValue)
|
||||
if headerKey != "" {
|
||||
return headerKey
|
||||
}
|
||||
return strings.TrimSpace(bodyValue)
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type apiEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||
h := NewDataManagementHandler(svc)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var envelope apiEnvelope
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||
require.Equal(t, 0, envelope.Code)
|
||||
|
||||
var data struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Reason string `json:"reason"`
|
||||
SocketPath string `json:"socket_path"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(envelope.Data, &data))
|
||||
require.False(t, data.Enabled)
|
||||
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
|
||||
require.Equal(t, svc.SocketPath(), data.SocketPath)
|
||||
}
|
||||
|
||||
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
|
||||
h := NewDataManagementHandler(svc)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var envelope apiEnvelope
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
|
||||
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
|
||||
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
|
||||
}
|
||||
|
||||
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
|
||||
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
|
||||
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
|
||||
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
|
||||
}
|
||||
@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
// Treat missing/invalid OAuth client configuration as a user/config error.
|
||||
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
|
||||
if strings.Contains(msg, "OAuth client not configured") ||
|
||||
strings.Contains(msg, "requires your own OAuth Client") ||
|
||||
strings.Contains(msg, "requires a custom OAuth Client") ||
|
||||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
|
||||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
|
||||
response.BadRequest(c, "Failed to generate auth URL: "+msg)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
|
||||
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@@ -47,6 +51,8 @@ type CreateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
// 从指定分组复制账号(创建后自动绑定)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -55,7 +61,7 @@ type CreateGroupRequest struct {
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
|
||||
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
|
||||
@@ -67,6 +73,10 @@ type UpdateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
@@ -76,6 +86,8 @@ type UpdateGroupRequest struct {
|
||||
MCPXMLInject *bool `json:"mcp_xml_inject"`
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string `json:"supported_model_scopes"`
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
|
||||
}
|
||||
@@ -179,6 +191,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@@ -186,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -225,6 +242,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
SoraImagePrice360: req.SoraImagePrice360,
|
||||
SoraImagePrice540: req.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
@@ -232,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
MCPXMLInject: req.MCPXMLInject,
|
||||
SupportedModelScopes: req.SupportedModelScopes,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
25
backend/internal/handler/admin/id_list_utils.go
Normal file
25
backend/internal/handler/admin/id_list_utils.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package admin
|
||||
|
||||
import "sort"
|
||||
|
||||
func normalizeInt64IDList(ids []int64) []int64 {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
|
||||
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
|
||||
return out
|
||||
}
|
||||
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
57
backend/internal/handler/admin/id_list_utils_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeInt64IDList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in []int64
|
||||
want []int64
|
||||
}{
|
||||
{"nil input", nil, nil},
|
||||
{"empty input", []int64{}, nil},
|
||||
{"single element", []int64{5}, []int64{5}},
|
||||
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
|
||||
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
|
||||
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
|
||||
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
|
||||
{"all invalid", []int64{0, -1, -2}, []int64{}},
|
||||
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := normalizeInt64IDList(tc.in)
|
||||
if tc.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.Equal(t, tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []int64
|
||||
want string
|
||||
}{
|
||||
{"empty", nil, "accounts_today_stats_empty"},
|
||||
{"single", []int64{42}, "accounts_today_stats:42"},
|
||||
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
115
backend/internal/handler/admin/idempotency_helper.go
Normal file
115
backend/internal/handler/admin/idempotency_helper.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type idempotencyStoreUnavailableMode int
|
||||
|
||||
const (
|
||||
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
|
||||
idempotencyStoreUnavailableFailOpen
|
||||
)
|
||||
|
||||
func executeAdminIdempotent(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) (*service.IdempotencyExecuteResult, error) {
|
||||
coordinator := service.DefaultIdempotencyCoordinator()
|
||||
if coordinator == nil {
|
||||
data, err := execute(c.Request.Context())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.IdempotencyExecuteResult{Data: data}, nil
|
||||
}
|
||||
|
||||
actorScope := "admin:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
|
||||
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
|
||||
Scope: scope,
|
||||
ActorScope: actorScope,
|
||||
Method: c.Request.Method,
|
||||
Route: c.FullPath(),
|
||||
IdempotencyKey: c.GetHeader("Idempotency-Key"),
|
||||
Payload: payload,
|
||||
RequireKey: true,
|
||||
TTL: ttl,
|
||||
}, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSON(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
|
||||
}
|
||||
|
||||
func executeAdminIdempotentJSONWithMode(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
mode idempotencyStoreUnavailableMode,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
|
||||
if err != nil {
|
||||
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
|
||||
strategy := "fail_close"
|
||||
if mode == idempotencyStoreUnavailableFailOpen {
|
||||
strategy = "fail_open"
|
||||
}
|
||||
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
|
||||
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
|
||||
if mode == idempotencyStoreUnavailableFailOpen {
|
||||
data, fallbackErr := execute(c.Request.Context())
|
||||
if fallbackErr != nil {
|
||||
response.ErrorFrom(c, fallbackErr)
|
||||
return
|
||||
}
|
||||
c.Header("X-Idempotency-Degraded", "store-unavailable")
|
||||
response.Success(c, data)
|
||||
return
|
||||
}
|
||||
}
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
285
backend/internal/handler/admin/idempotency_helper_test.go
Normal file
285
backend/internal/handler/admin/idempotency_helper_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type storeUnavailableRepoStub struct{}
|
||||
|
||||
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "test-key-1")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "test-key-2")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
|
||||
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
|
||||
}
|
||||
|
||||
type memoryIdempotencyRepoStub struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*service.IdempotencyRecord
|
||||
}
|
||||
|
||||
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
|
||||
return &memoryIdempotencyRepoStub{
|
||||
nextID: 1,
|
||||
data: make(map[string]*service.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
|
||||
return scope + "|" + keyHash
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
cp := r.clone(record)
|
||||
cp.ID = r.nextID
|
||||
r.nextID++
|
||||
r.data[k] = cp
|
||||
record.ID = cp.ID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.clone(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
rec.ErrorReason = nil
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newMemoryIdempotencyRepoStub()
|
||||
cfg := service.DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed atomic.Int32
|
||||
router := gin.New()
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed.Add(1)
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
call := func() (int, http.Header) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "same-key")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
return rec.Code, rec.Header()
|
||||
}
|
||||
|
||||
var status1, status2 int
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
status1, _ = call()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
status2, _ = call()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
|
||||
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
|
||||
|
||||
status3, headers3 := call()
|
||||
require.Equal(t, http.StatusOK, status3)
|
||||
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -16,6 +18,13 @@ type OpenAIOAuthHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
func oauthPlatformFromPath(c *gin.Context) string {
|
||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||
return service.PlatformSora
|
||||
}
|
||||
return service.PlatformOpenAI
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
@@ -39,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
req = OpenAIGenerateAuthURLRequest{}
|
||||
}
|
||||
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
|
||||
result, err := h.openaiOAuthService.GenerateAuthURL(
|
||||
c.Request.Context(),
|
||||
req.ProxyID,
|
||||
req.RedirectURI,
|
||||
oauthPlatformFromPath(c),
|
||||
)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -52,6 +66,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
@@ -68,6 +83,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -81,18 +97,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RT string `json:"rt"`
|
||||
ClientID string `json:"client_id"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
// POST /api/v1/admin/sora/rt2at
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
refreshToken = strings.TrimSpace(req.RT)
|
||||
}
|
||||
if refreshToken == "" {
|
||||
response.BadRequest(c, "refresh_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
@@ -102,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
|
||||
clientID := strings.TrimSpace(req.ClientID)
|
||||
if clientID == "" {
|
||||
platform := oauthPlatformFromPath(c)
|
||||
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -111,8 +145,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||
// POST /api/v1/admin/sora/st2at
|
||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
ST string `json:"st"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||
if sessionToken == "" {
|
||||
sessionToken = strings.TrimSpace(req.ST)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
response.BadRequest(c, "session_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -127,9 +192,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
platform := oauthPlatformFromPath(c)
|
||||
if account.Platform != platform {
|
||||
response.BadRequest(c, "Account platform does not match OAuth endpoint")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -167,12 +232,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
// POST /api/v1/admin/sora/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
@@ -189,6 +256,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -200,19 +268,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
platform := oauthPlatformFromPath(c)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
if platform == service.PlatformSora {
|
||||
name = "Sora OAuth Account"
|
||||
} else {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
|
||||
response.Success(c, data)
|
||||
}
|
||||
|
||||
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
|
||||
// GET /api/v1/admin/ops/dashboard/openai-token-stats
|
||||
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
filter, err := parseOpsOpenAITokenStatsFilter(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
}
|
||||
|
||||
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("invalid request")
|
||||
}
|
||||
|
||||
timeRange := strings.TrimSpace(c.Query("time_range"))
|
||||
if timeRange == "" {
|
||||
timeRange = "30d"
|
||||
}
|
||||
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid time_range")
|
||||
}
|
||||
end := time.Now().UTC()
|
||||
start := end.Add(-dur)
|
||||
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: timeRange,
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return nil, fmt.Errorf("invalid group_id")
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
|
||||
topNRaw := strings.TrimSpace(c.Query("top_n"))
|
||||
pageRaw := strings.TrimSpace(c.Query("page"))
|
||||
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
|
||||
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
|
||||
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
|
||||
}
|
||||
|
||||
if topNRaw != "" {
|
||||
topN, err := strconv.Atoi(topNRaw)
|
||||
if err != nil || topN < 1 || topN > 100 {
|
||||
return nil, fmt.Errorf("invalid top_n")
|
||||
}
|
||||
filter.TopN = topN
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
filter.Page = 1
|
||||
filter.PageSize = 20
|
||||
if pageRaw != "" {
|
||||
page, err := strconv.Atoi(pageRaw)
|
||||
if err != nil || page < 1 {
|
||||
return nil, fmt.Errorf("invalid page")
|
||||
}
|
||||
filter.Page = page
|
||||
}
|
||||
if pageSizeRaw != "" {
|
||||
pageSize, err := strconv.Atoi(pageSizeRaw)
|
||||
if err != nil || pageSize < 1 || pageSize > 100 {
|
||||
return nil, fmt.Errorf("invalid page_size")
|
||||
}
|
||||
filter.PageSize = pageSize
|
||||
}
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
|
||||
switch strings.TrimSpace(v) {
|
||||
case "30m":
|
||||
return 30 * time.Minute, true
|
||||
case "1h":
|
||||
return time.Hour, true
|
||||
case "1d":
|
||||
return 24 * time.Hour, true
|
||||
case "15d":
|
||||
return 15 * 24 * time.Hour, true
|
||||
case "30d":
|
||||
return 30 * 24 * time.Hour, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func pickThroughputBucketSeconds(window time.Duration) int {
|
||||
// Keep buckets predictable and avoid huge responses.
|
||||
switch {
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func newTestSettingRepo() *testSettingRepo {
|
||||
return &testSettingRepo{values: map[string]string{}}
|
||||
}
|
||||
|
||||
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
v, err := s.GetValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &service.Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
|
||||
v, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
|
||||
s.values[key] = value
|
||||
return nil
|
||||
}
|
||||
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, k := range keys {
|
||||
if v, ok := s.values[k]; ok {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
for k, v := range settings {
|
||||
s.values[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(s.values))
|
||||
for k, v := range s.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
|
||||
delete(s.values, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
if withUser {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
|
||||
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
|
||||
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
|
||||
return r
|
||||
}
|
||||
|
||||
func newRuntimeOpsService(t *testing.T) *service.OpsService {
|
||||
t.Helper()
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
|
||||
settingRepo := newTestSettingRepo()
|
||||
cfg := &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: true},
|
||||
Log: config.LogConfig{
|
||||
Level: "info",
|
||||
Caller: true,
|
||||
StacktraceLevel: "error",
|
||||
Sampling: config.LogSamplingConfig{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, false)
|
||||
|
||||
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
|
||||
h := NewOpsHandler(newRuntimeOpsService(t))
|
||||
r := newOpsRuntimeRouter(h, true)
|
||||
|
||||
payload := map[string]any{
|
||||
"level": "debug",
|
||||
"enable_sampling": false,
|
||||
"sampling_initial": 100,
|
||||
"sampling_thereafter": 100,
|
||||
"caller": true,
|
||||
"stacktrace_level": "error",
|
||||
"retention_days": 30,
|
||||
}
|
||||
raw, _ := json.Marshal(payload)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// GetRuntimeLogConfig returns runtime log config (DB-backed).
|
||||
// GET /api/v1/admin/ops/runtime/logging
|
||||
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
|
||||
return
|
||||
}
|
||||
response.Success(c, cfg)
|
||||
}
|
||||
|
||||
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
|
||||
// PUT /api/v1/admin/ops/runtime/logging
|
||||
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var req service.OpsRuntimeLogConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
|
||||
// POST /api/v1/admin/ops/runtime/logging/reset
|
||||
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
|
||||
// GET /api/v1/admin/ops/advanced-settings
|
||||
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
|
||||
|
||||
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
145
backend/internal/handler/admin/ops_snapshot_v2_handler.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
type opsDashboardSnapshotV2Response struct {
|
||||
GeneratedAt string `json:"generated_at"`
|
||||
|
||||
Overview *service.OpsDashboardOverview `json:"overview"`
|
||||
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
|
||||
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
|
||||
}
|
||||
|
||||
type opsDashboardSnapshotV2CacheKey struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
QueryMode service.OpsQueryMode `json:"mode"`
|
||||
BucketSecond int `json:"bucket_second"`
|
||||
}
|
||||
|
||||
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
|
||||
// GET /api/v1/admin/ops/dashboard/snapshot-v2
|
||||
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
startTime, endTime, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsDashboardFilter{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
QueryMode: parseOpsQueryMode(c),
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
|
||||
id, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid group_id")
|
||||
return
|
||||
}
|
||||
filter.GroupID = &id
|
||||
}
|
||||
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
|
||||
|
||||
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
|
||||
StartTime: startTime.UTC().Format(time.RFC3339),
|
||||
EndTime: endTime.UTC().Format(time.RFC3339),
|
||||
Platform: filter.Platform,
|
||||
GroupID: filter.GroupID,
|
||||
QueryMode: filter.QueryMode,
|
||||
BucketSecond: bucketSeconds,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
|
||||
c.Status(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
overview *service.OpsDashboardOverview
|
||||
trend *service.OpsThroughputTrendResponse
|
||||
errTrend *service.OpsErrorTrendResponse
|
||||
)
|
||||
g, gctx := errgroup.WithContext(c.Request.Context())
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetDashboardOverview(gctx, &f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
overview = result
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
trend = result
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
f := *filter
|
||||
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
errTrend = result
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &opsDashboardSnapshotV2Response{
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
Overview: overview,
|
||||
ThroughputTrend: trend,
|
||||
ErrorTrend: errTrend,
|
||||
}
|
||||
|
||||
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
|
||||
if cached.ETag != "" {
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Vary", "If-None-Match")
|
||||
}
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, resp)
|
||||
}
|
||||
174
backend/internal/handler/admin/ops_system_log_handler.go
Normal file
174
backend/internal/handler/admin/ops_system_log_handler.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type opsSystemLogCleanupRequest struct {
|
||||
StartTime string `json:"start_time"`
|
||||
EndTime string `json:"end_time"`
|
||||
|
||||
Level string `json:"level"`
|
||||
Component string `json:"component"`
|
||||
RequestID string `json:"request_id"`
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
Query string `json:"q"`
|
||||
}
|
||||
|
||||
// ListSystemLogs returns indexed system logs.
|
||||
// GET /api/v1/admin/ops/system-logs
|
||||
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
if pageSize > 200 {
|
||||
pageSize = 200
|
||||
}
|
||||
|
||||
start, end, err := parseOpsTimeRange(c, "1h")
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsSystemLogFilter{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
Level: strings.TrimSpace(c.Query("level")),
|
||||
Component: strings.TrimSpace(c.Query("component")),
|
||||
RequestID: strings.TrimSpace(c.Query("request_id")),
|
||||
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
|
||||
Platform: strings.TrimSpace(c.Query("platform")),
|
||||
Model: strings.TrimSpace(c.Query("model")),
|
||||
Query: strings.TrimSpace(c.Query("q")),
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
|
||||
id, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid user_id")
|
||||
return
|
||||
}
|
||||
filter.UserID = &id
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
|
||||
id, parseErr := strconv.ParseInt(v, 10, 64)
|
||||
if parseErr != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid account_id")
|
||||
return
|
||||
}
|
||||
filter.AccountID = &id
|
||||
}
|
||||
|
||||
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// CleanupSystemLogs deletes indexed system logs by filter.
|
||||
// POST /api/v1/admin/ops/system-logs/cleanup
|
||||
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req opsSystemLogCleanupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
parseTS := func(raw string) (*time.Time, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
|
||||
return &t, nil
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
start, err := parseTS(req.StartTime)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_time")
|
||||
return
|
||||
}
|
||||
end, err := parseTS(req.EndTime)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_time")
|
||||
return
|
||||
}
|
||||
|
||||
filter := &service.OpsSystemLogCleanupFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Level: strings.TrimSpace(req.Level),
|
||||
Component: strings.TrimSpace(req.Component),
|
||||
RequestID: strings.TrimSpace(req.RequestID),
|
||||
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
|
||||
UserID: req.UserID,
|
||||
AccountID: req.AccountID,
|
||||
Platform: strings.TrimSpace(req.Platform),
|
||||
Model: strings.TrimSpace(req.Model),
|
||||
Query: strings.TrimSpace(req.Query),
|
||||
}
|
||||
|
||||
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": deleted})
|
||||
}
|
||||
|
||||
// GetSystemLogIngestionHealth returns sink health metrics.
|
||||
// GET /api/v1/admin/ops/system-logs/health
|
||||
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, h.opsService.GetSystemLogSinkHealth())
|
||||
}
|
||||
233
backend/internal/handler/admin/ops_system_log_handler_test.go
Normal file
233
backend/internal/handler/admin/ops_system_log_handler_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type responseEnvelope struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
if withUser {
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
r.GET("/logs", handler.ListSystemLogs)
|
||||
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
|
||||
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
|
||||
return r
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp responseEnvelope
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if resp.Code != 0 {
|
||||
t.Fatalf("unexpected response code: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status=%d, want 401", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status=%d, want 400", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, true)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_Health(t *testing.T) {
|
||||
sink := service.NewOpsSystemLogSink(nil)
|
||||
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
|
||||
h := NewOpsHandler(svc)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d, want 200", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
|
||||
h := NewOpsHandler(nil)
|
||||
r := newOpsSystemLogTestRouter(h, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status=%d, want 503", w.Code)
|
||||
}
|
||||
|
||||
svc := service.NewOpsService(nil, nil, &config.Config{
|
||||
Ops: config.OpsConfig{Enabled: false},
|
||||
}, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
h = NewOpsHandler(svc)
|
||||
r = newOpsSystemLogTestRouter(h, false)
|
||||
w = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Fatalf("status=%d, want 404", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package admin
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -62,7 +62,8 @@ const (
|
||||
)
|
||||
|
||||
var wsConnCount atomic.Int32
|
||||
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
|
||||
var wsConnCountByIPMu sync.Mutex
|
||||
var wsConnCountByIP = make(map[string]int32)
|
||||
|
||||
const qpsWSIdleStopDelay = 30 * time.Second
|
||||
|
||||
@@ -252,7 +253,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
|
||||
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
|
||||
if err != nil || stats == nil {
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -278,7 +279,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
|
||||
|
||||
msg, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -338,7 +339,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
// Reserve a global slot before upgrading the connection to keep the limit strict.
|
||||
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
|
||||
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
@@ -350,7 +351,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
|
||||
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
|
||||
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
|
||||
return
|
||||
}
|
||||
@@ -359,7 +360,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("[OpsWS] upgrade failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
|
||||
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
|
||||
counter, ok := v.(*atomic.Int32)
|
||||
if !ok {
|
||||
wsConnCountByIPMu.Lock()
|
||||
defer wsConnCountByIPMu.Unlock()
|
||||
current := wsConnCountByIP[clientIP]
|
||||
if current >= limit {
|
||||
return false
|
||||
}
|
||||
|
||||
for {
|
||||
current := counter.Load()
|
||||
if current >= limit {
|
||||
return false
|
||||
}
|
||||
if counter.CompareAndSwap(current, current+1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
wsConnCountByIP[clientIP] = current + 1
|
||||
return true
|
||||
}
|
||||
|
||||
func releaseOpsWSIPSlot(clientIP string) {
|
||||
if strings.TrimSpace(clientIP) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
v, ok := wsConnCountByIP.Load(clientIP)
|
||||
wsConnCountByIPMu.Lock()
|
||||
defer wsConnCountByIPMu.Unlock()
|
||||
current, ok := wsConnCountByIP[clientIP]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
counter, ok := v.(*atomic.Int32)
|
||||
if !ok {
|
||||
if current <= 1 {
|
||||
delete(wsConnCountByIP, clientIP)
|
||||
return
|
||||
}
|
||||
next := counter.Add(-1)
|
||||
if next <= 0 {
|
||||
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
|
||||
wsConnCountByIP.Delete(clientIP)
|
||||
}
|
||||
wsConnCountByIP[clientIP] = current - 1
|
||||
}
|
||||
|
||||
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
@@ -452,7 +442,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
conn.SetReadLimit(qpsWSMaxReadBytes)
|
||||
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
|
||||
log.Printf("[OpsWS] set read deadline failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetPongHandler(func(string) error {
|
||||
@@ -471,7 +461,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Printf("[OpsWS] read failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -508,7 +498,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
continue
|
||||
}
|
||||
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
|
||||
log.Printf("[OpsWS] write failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
@@ -517,7 +507,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
|
||||
|
||||
case <-pingTicker.C:
|
||||
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("[OpsWS] ping failed: %v", err)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
|
||||
cancel()
|
||||
closeConn()
|
||||
wg.Wait()
|
||||
@@ -666,14 +656,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
if parsed, err := strconv.ParseBool(v); err == nil {
|
||||
cfg.TrustProxy = parsed
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
|
||||
}
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
|
||||
prefixes, invalid := parseTrustedProxyList(raw)
|
||||
if len(invalid) > 0 {
|
||||
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
|
||||
}
|
||||
cfg.TrustedProxies = prefixes
|
||||
}
|
||||
@@ -684,7 +674,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
|
||||
case OriginPolicyStrict, OriginPolicyPermissive:
|
||||
cfg.OriginPolicy = normalized
|
||||
default:
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -701,14 +691,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
cfg.MaxConns = int32(parsed)
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
|
||||
cfg.MaxConnsPerIP = int32(parsed)
|
||||
} else {
|
||||
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -63,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
|
||||
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
|
||||
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -82,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
|
||||
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
|
||||
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
return
|
||||
@@ -96,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.Proxy, 0, len(proxies))
|
||||
out := make([]dto.AdminProxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *dto.ProxyFromService(&proxies[i]))
|
||||
out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
@@ -118,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
|
||||
}
|
||||
|
||||
// Create handles creating a new proxy
|
||||
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
Protocol: strings.TrimSpace(req.Protocol),
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
Port: req.Port,
|
||||
Username: strings.TrimSpace(req.Username),
|
||||
Password: strings.TrimSpace(req.Password),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dto.ProxyFromServiceAdmin(proxy), nil
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
}
|
||||
|
||||
// Update handles updating a proxy
|
||||
@@ -175,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ProxyFromService(proxy))
|
||||
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
|
||||
}
|
||||
|
||||
// Delete handles deleting a proxy
|
||||
@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// CheckQuality handles checking proxy quality across common AI targets.
|
||||
// POST /api/v1/admin/proxies/:id/quality-check
|
||||
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetStats handles getting proxy statistics
|
||||
// GET /api/v1/admin/proxies/:id/stats
|
||||
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
||||
|
||||
@@ -2,12 +2,15 @@ package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -16,13 +19,15 @@ import (
|
||||
|
||||
// RedeemHandler handles admin redeem code management
|
||||
type RedeemHandler struct {
|
||||
adminService service.AdminService
|
||||
adminService service.AdminService
|
||||
redeemService *service.RedeemService
|
||||
}
|
||||
|
||||
// NewRedeemHandler creates a new admin redeem handler
|
||||
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
|
||||
func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler {
|
||||
return &RedeemHandler{
|
||||
adminService: adminService,
|
||||
adminService: adminService,
|
||||
redeemService: redeemService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +40,15 @@ type GenerateRedeemCodesRequest struct {
|
||||
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value" binding:"required,gt=0"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
// GET /api/v1/admin/redeem-codes
|
||||
func (h *RedeemHandler) List(c *gin.Context) {
|
||||
@@ -88,23 +102,99 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
return out, nil
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
}
|
||||
|
||||
// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step.
|
||||
// POST /api/v1/admin/redeem-codes/create-and-redeem
|
||||
func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
if h.redeemService == nil {
|
||||
response.InternalError(c, "redeem service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
var req CreateAndRedeemCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, out)
|
||||
req.Code = strings.TrimSpace(req.Code)
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
if err == nil {
|
||||
return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID)
|
||||
}
|
||||
if !errors.Is(err, service.ErrRedeemCodeNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
|
||||
Code: req.Code,
|
||||
Type: req.Type,
|
||||
Value: req.Value,
|
||||
Status: service.StatusUnused,
|
||||
Notes: req.Notes,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code)
|
||||
if getErr == nil {
|
||||
return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID)
|
||||
}
|
||||
return nil, createErr
|
||||
}
|
||||
|
||||
redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code)
|
||||
if redeemErr != nil {
|
||||
return nil, redeemErr
|
||||
}
|
||||
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) {
|
||||
if existing == nil {
|
||||
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict")
|
||||
}
|
||||
|
||||
// If previous run created the code but crashed before redeem, redeem it now.
|
||||
if existing.CanUse() {
|
||||
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
|
||||
if err == nil {
|
||||
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
|
||||
}
|
||||
if !errors.Is(err, service.ErrRedeemCodeUsed) {
|
||||
return nil, err
|
||||
}
|
||||
latest, getErr := h.redeemService.GetByCode(ctx, existing.Code)
|
||||
if getErr == nil {
|
||||
existing = latest
|
||||
}
|
||||
}
|
||||
|
||||
if existing.UsedBy != nil && *existing.UsedBy == userID {
|
||||
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil
|
||||
}
|
||||
|
||||
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user")
|
||||
}
|
||||
|
||||
// Delete handles deleting a redeem code
|
||||
|
||||
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
|
||||
func truncateSearchByRune(search string, maxRunes int) string {
|
||||
if runes := []rune(search); len(runes) > maxRunes {
|
||||
return string(runes[:maxRunes])
|
||||
}
|
||||
return search
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxRunes int
|
||||
wantLen int // 期望的 rune 长度
|
||||
}{
|
||||
{
|
||||
name: "纯中文超长",
|
||||
input: string(make([]rune, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "纯 ASCII 超长",
|
||||
input: string(make([]byte, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
maxRunes: 100,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "恰好 100 个字符",
|
||||
input: string(make([]rune, 100)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "不足 100 字符不截断",
|
||||
input: "hello世界",
|
||||
maxRunes: 100,
|
||||
wantLen: 7,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := truncateSearchByRune(tc.input, tc.maxRunes)
|
||||
require.Equal(t, tc.wantLen, len([]rune(result)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
|
||||
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
|
||||
input := ""
|
||||
for i := 0; i < 101; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
require.Equal(t, 100, len([]rune(result)))
|
||||
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
|
||||
require.Equal(t, 300, len(result))
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
|
||||
// 50 个 ASCII + 51 个中文 = 101 个 rune
|
||||
input := ""
|
||||
for i := 0; i < 50; i++ {
|
||||
input += "a"
|
||||
}
|
||||
for i := 0; i < 51; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
runes := []rune(result)
|
||||
require.Equal(t, 100, len(runes))
|
||||
// 前 50 个应该是 'a',后 50 个应该是 '中'
|
||||
require.Equal(t, 'a', runes[0])
|
||||
require.Equal(t, 'a', runes[49])
|
||||
require.Equal(t, '中', runes[50])
|
||||
require.Equal(t, '中', runes[99])
|
||||
}
|
||||
@@ -1,7 +1,13 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,21 +20,38 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// semverPattern 预编译 semver 格式校验正则
|
||||
var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
|
||||
|
||||
// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only.
|
||||
var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
|
||||
// generateMenuItemID generates a short random hex ID for a custom menu item.
|
||||
func generateMenuItemID() (string, error) {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate menu item ID: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// SettingHandler 系统设置处理器
|
||||
type SettingHandler struct {
|
||||
settingService *service.SettingService
|
||||
emailService *service.EmailService
|
||||
turnstileService *service.TurnstileService
|
||||
opsService *service.OpsService
|
||||
soraS3Storage *service.SoraS3Storage
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
opsService: opsService,
|
||||
soraS3Storage: soraS3Storage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,10 +66,18 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
|
||||
// Check if ops monitoring is enabled (respects config.ops.enabled)
|
||||
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
|
||||
defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
|
||||
for _, sub := range settings.DefaultSubscriptions {
|
||||
defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
|
||||
GroupID: sub.GroupID,
|
||||
ValidityDays: sub.ValidityDays,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
@@ -76,8 +107,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
@@ -89,18 +123,21 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -123,20 +160,23 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -154,6 +194,11 @@ type UpdateSettingsRequest struct {
|
||||
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
|
||||
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
|
||||
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
|
||||
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -181,6 +226,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
|
||||
|
||||
// Turnstile 参数验证
|
||||
if req.TurnstileEnabled {
|
||||
@@ -276,6 +322,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义菜单项验证
|
||||
const (
|
||||
maxCustomMenuItems = 20
|
||||
maxMenuItemLabelLen = 50
|
||||
maxMenuItemURLLen = 2048
|
||||
maxMenuItemIconSVGLen = 10 * 1024 // 10KB
|
||||
maxMenuItemIDLen = 32
|
||||
)
|
||||
|
||||
customMenuJSON := previousSettings.CustomMenuItems
|
||||
if req.CustomMenuItems != nil {
|
||||
items := *req.CustomMenuItems
|
||||
if len(items) > maxCustomMenuItems {
|
||||
response.BadRequest(c, "Too many custom menu items (max 20)")
|
||||
return
|
||||
}
|
||||
for i, item := range items {
|
||||
if strings.TrimSpace(item.Label) == "" {
|
||||
response.BadRequest(c, "Custom menu item label is required")
|
||||
return
|
||||
}
|
||||
if len(item.Label) > maxMenuItemLabelLen {
|
||||
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(item.URL) == "" {
|
||||
response.BadRequest(c, "Custom menu item URL is required")
|
||||
return
|
||||
}
|
||||
if len(item.URL) > maxMenuItemURLLen {
|
||||
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
|
||||
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
if item.Visibility != "user" && item.Visibility != "admin" {
|
||||
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
|
||||
return
|
||||
}
|
||||
if len(item.IconSVG) > maxMenuItemIconSVGLen {
|
||||
response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)")
|
||||
return
|
||||
}
|
||||
// Auto-generate ID if missing
|
||||
if strings.TrimSpace(item.ID) == "" {
|
||||
id, err := generateMenuItemID()
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID")
|
||||
return
|
||||
}
|
||||
items[i].ID = id
|
||||
} else if len(item.ID) > maxMenuItemIDLen {
|
||||
response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)")
|
||||
return
|
||||
} else if !menuItemIDPattern.MatchString(item.ID) {
|
||||
response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)")
|
||||
return
|
||||
}
|
||||
}
|
||||
// ID uniqueness check
|
||||
seen := make(map[string]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
if _, exists := seen[item.ID]; exists {
|
||||
response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID)
|
||||
return
|
||||
}
|
||||
seen[item.ID] = struct{}{}
|
||||
}
|
||||
menuBytes, err := json.Marshal(items)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Failed to serialize custom menu items")
|
||||
return
|
||||
}
|
||||
customMenuJSON = string(menuBytes)
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -287,47 +411,68 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
req.OpsMetricsIntervalSeconds = &v
|
||||
}
|
||||
defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
|
||||
for _, sub := range req.DefaultSubscriptions {
|
||||
defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
|
||||
GroupID: sub.GroupID,
|
||||
ValidityDays: sub.ValidityDays,
|
||||
})
|
||||
}
|
||||
|
||||
// 验证最低版本号格式(空字符串=禁用,或合法 semver)
|
||||
if req.MinClaudeCodeVersion != "" {
|
||||
if !semverPattern.MatchString(req.MinClaudeCodeVersion) {
|
||||
response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
SoraClientEnabled: req.SoraClientEnabled,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -367,10 +512,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
|
||||
for _, sub := range updatedSettings.DefaultSubscriptions {
|
||||
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
|
||||
GroupID: sub.GroupID,
|
||||
ValidityDays: sub.ValidityDays,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
|
||||
@@ -400,8 +553,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
|
||||
SoraClientEnabled: updatedSettings.SoraClientEnabled,
|
||||
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
DefaultSubscriptions: updatedDefaultSubscriptions,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
@@ -413,6 +569,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -444,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||
changed = append(changed, "registration_email_suffix_whitelist")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
@@ -522,6 +683,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.DefaultBalance != after.DefaultBalance {
|
||||
changed = append(changed, "default_balance")
|
||||
}
|
||||
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
|
||||
changed = append(changed, "default_subscriptions")
|
||||
}
|
||||
if before.EnableModelFallback != after.EnableModelFallback {
|
||||
changed = append(changed, "enable_model_fallback")
|
||||
}
|
||||
@@ -555,9 +719,65 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
|
||||
changed = append(changed, "ops_metrics_interval_seconds")
|
||||
}
|
||||
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
|
||||
changed = append(changed, "min_claude_code_version")
|
||||
}
|
||||
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
|
||||
changed = append(changed, "allow_ungrouped_key_scheduling")
|
||||
}
|
||||
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
|
||||
changed = append(changed, "purchase_subscription_enabled")
|
||||
}
|
||||
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
|
||||
changed = append(changed, "purchase_subscription_url")
|
||||
}
|
||||
if before.CustomMenuItems != after.CustomMenuItems {
|
||||
changed = append(changed, "custom_menu_items")
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
|
||||
for _, item := range input {
|
||||
if item.GroupID <= 0 || item.ValidityDays <= 0 {
|
||||
continue
|
||||
}
|
||||
if item.ValidityDays > service.MaxValidityDays {
|
||||
item.ValidityDays = service.MaxValidityDays
|
||||
}
|
||||
normalized = append(normalized, item)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func equalStringSlice(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
@@ -750,6 +970,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
|
||||
if settings == nil {
|
||||
return dto.SoraS3Settings{}
|
||||
}
|
||||
return dto.SoraS3Settings{
|
||||
Enabled: settings.Enabled,
|
||||
Endpoint: settings.Endpoint,
|
||||
Region: settings.Region,
|
||||
Bucket: settings.Bucket,
|
||||
AccessKeyID: settings.AccessKeyID,
|
||||
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
|
||||
Prefix: settings.Prefix,
|
||||
ForcePathStyle: settings.ForcePathStyle,
|
||||
CDNURL: settings.CDNURL,
|
||||
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
|
||||
}
|
||||
}
|
||||
|
||||
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
|
||||
return dto.SoraS3Profile{
|
||||
ProfileID: profile.ProfileID,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
Enabled: profile.Enabled,
|
||||
Endpoint: profile.Endpoint,
|
||||
Region: profile.Region,
|
||||
Bucket: profile.Bucket,
|
||||
AccessKeyID: profile.AccessKeyID,
|
||||
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
|
||||
Prefix: profile.Prefix,
|
||||
ForcePathStyle: profile.ForcePathStyle,
|
||||
CDNURL: profile.CDNURL,
|
||||
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(endpoint) == "" {
|
||||
return fmt.Errorf("S3 Endpoint is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(bucket) == "" {
|
||||
return fmt.Errorf("S3 Bucket is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(accessKeyID) == "" {
|
||||
return fmt.Errorf("S3 Access Key ID is required when enabled")
|
||||
}
|
||||
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("S3 Secret Access Key is required when enabled")
|
||||
}
|
||||
|
||||
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
|
||||
for idx := range items {
|
||||
if items[idx].ProfileID == profileID {
|
||||
return &items[idx]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// GET /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
|
||||
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(settings))
|
||||
}
|
||||
|
||||
// ListSoraS3Profiles 获取 Sora S3 多配置
|
||||
// GET /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
|
||||
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
items := make([]dto.SoraS3Profile, 0, len(result.Items))
|
||||
for idx := range result.Items {
|
||||
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
|
||||
}
|
||||
response.Success(c, dto.ListSoraS3ProfilesResponse{
|
||||
ActiveProfileID: result.ActiveProfileID,
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
|
||||
type UpdateSoraS3SettingsRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type CreateSoraS3ProfileRequest struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
SetActive bool `json:"set_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
type UpdateSoraS3ProfileRequest struct {
|
||||
Name string `json:"name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// CreateSoraS3Profile 创建 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles
|
||||
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
|
||||
var req CreateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.ProfileID) == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
|
||||
ProfileID: req.ProfileID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}, req.SetActive)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*created))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Profile 更新 Sora S3 配置
|
||||
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3ProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
response.BadRequest(c, "Name is required")
|
||||
return
|
||||
}
|
||||
|
||||
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
existing := findSoraS3ProfileByID(existingList.Items, profileID)
|
||||
if existing == nil {
|
||||
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
|
||||
return
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.ErrorFrom(c, updateErr)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, toSoraS3ProfileDTO(*updated))
|
||||
}
|
||||
|
||||
// DeleteSoraS3Profile 删除 Sora S3 配置
|
||||
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
|
||||
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
|
||||
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
|
||||
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
|
||||
profileID := strings.TrimSpace(c.Param("profile_id"))
|
||||
if profileID == "" {
|
||||
response.BadRequest(c, "Profile ID is required")
|
||||
return
|
||||
}
|
||||
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3ProfileDTO(*active))
|
||||
}
|
||||
|
||||
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
|
||||
// PUT /api/v1/admin/settings/sora-s3
|
||||
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if req.DefaultStorageQuotaBytes < 0 {
|
||||
req.DefaultStorageQuotaBytes = 0
|
||||
}
|
||||
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
settings := &service.SoraS3Settings{
|
||||
Enabled: req.Enabled,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
|
||||
}
|
||||
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
|
||||
}
|
||||
|
||||
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
|
||||
// POST /api/v1/admin/settings/sora-s3/test
|
||||
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
|
||||
if h.soraS3Storage == nil {
|
||||
response.Error(c, 500, "S3 存储服务未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSoraS3SettingsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if !req.Enabled {
|
||||
response.BadRequest(c, "S3 未启用,无法测试连接")
|
||||
return
|
||||
}
|
||||
|
||||
if req.SecretAccessKey == "" {
|
||||
if req.ProfileID != "" {
|
||||
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
|
||||
if err == nil {
|
||||
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
|
||||
if profile != nil {
|
||||
req.SecretAccessKey = profile.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.SecretAccessKey == "" {
|
||||
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
|
||||
if err == nil {
|
||||
req.SecretAccessKey = existing.SecretAccessKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testCfg := &service.SoraS3Settings{
|
||||
Enabled: true,
|
||||
Endpoint: req.Endpoint,
|
||||
Region: req.Region,
|
||||
Bucket: req.Bucket,
|
||||
AccessKeyID: req.AccessKeyID,
|
||||
SecretAccessKey: req.SecretAccessKey,
|
||||
Prefix: req.Prefix,
|
||||
ForcePathStyle: req.ForcePathStyle,
|
||||
CDNURL: req.CDNURL,
|
||||
}
|
||||
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
|
||||
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{"message": "S3 连接成功"})
|
||||
}
|
||||
|
||||
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
|
||||
type UpdateStreamTimeoutSettingsRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
95
backend/internal/handler/admin/snapshot_cache.go
Normal file
95
backend/internal/handler/admin/snapshot_cache.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type snapshotCacheEntry struct {
|
||||
ETag string
|
||||
Payload any
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type snapshotCache struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
items map[string]snapshotCacheEntry
|
||||
}
|
||||
|
||||
func newSnapshotCache(ttl time.Duration) *snapshotCache {
|
||||
if ttl <= 0 {
|
||||
ttl = 30 * time.Second
|
||||
}
|
||||
return &snapshotCache{
|
||||
ttl: ttl,
|
||||
items: make(map[string]snapshotCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
|
||||
if c == nil || key == "" {
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
c.mu.RLock()
|
||||
entry, ok := c.items[key]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
if now.After(entry.ExpiresAt) {
|
||||
c.mu.Lock()
|
||||
delete(c.items, key)
|
||||
c.mu.Unlock()
|
||||
return snapshotCacheEntry{}, false
|
||||
}
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
|
||||
if c == nil {
|
||||
return snapshotCacheEntry{}
|
||||
}
|
||||
entry := snapshotCacheEntry{
|
||||
ETag: buildETagFromAny(payload),
|
||||
Payload: payload,
|
||||
ExpiresAt: time.Now().Add(c.ttl),
|
||||
}
|
||||
if key == "" {
|
||||
return entry
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.items[key] = entry
|
||||
c.mu.Unlock()
|
||||
return entry
|
||||
}
|
||||
|
||||
func buildETagFromAny(payload any) string {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(raw)
|
||||
return "\"" + hex.EncodeToString(sum[:]) + "\""
|
||||
}
|
||||
|
||||
func parseBoolQueryWithDefault(raw string, def bool) bool {
|
||||
value := strings.TrimSpace(strings.ToLower(raw))
|
||||
if value == "" {
|
||||
return def
|
||||
}
|
||||
switch value {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
128
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
128
backend/internal/handler/admin/snapshot_cache_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSnapshotCache_SetAndGet(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
|
||||
entry := c.Set("key1", map[string]string{"hello": "world"})
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
require.NotNil(t, entry.Payload)
|
||||
|
||||
got, ok := c.Get("key1")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, entry.ETag, got.ETag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_Expiration(t *testing.T) {
|
||||
c := newSnapshotCache(1 * time.Millisecond)
|
||||
|
||||
c.Set("key1", "value")
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
_, ok := c.Get("key1")
|
||||
require.False(t, ok, "expired entry should not be returned")
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
_, ok := c.Get("")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_GetMiss(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
_, ok := c.Get("nonexistent")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_NilReceiver(t *testing.T) {
|
||||
var c *snapshotCache
|
||||
_, ok := c.Get("key")
|
||||
require.False(t, ok)
|
||||
|
||||
entry := c.Set("key", "value")
|
||||
require.Empty(t, entry.ETag)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
|
||||
// Set with empty key should return entry but not store it
|
||||
entry := c.Set("", "value")
|
||||
require.NotEmpty(t, entry.ETag)
|
||||
|
||||
_, ok := c.Get("")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_DefaultTTL(t *testing.T) {
|
||||
c := newSnapshotCache(0)
|
||||
require.Equal(t, 30*time.Second, c.ttl)
|
||||
|
||||
c2 := newSnapshotCache(-1 * time.Second)
|
||||
require.Equal(t, 30*time.Second, c2.ttl)
|
||||
}
|
||||
|
||||
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
payload := map[string]int{"a": 1, "b": 2}
|
||||
|
||||
entry1 := c.Set("k1", payload)
|
||||
entry2 := c.Set("k2", payload)
|
||||
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
|
||||
}
|
||||
|
||||
func TestSnapshotCache_ETagFormat(t *testing.T) {
|
||||
c := newSnapshotCache(5 * time.Second)
|
||||
entry := c.Set("k", "test")
|
||||
// ETag should be quoted hex string: "abcdef..."
|
||||
require.True(t, len(entry.ETag) > 2)
|
||||
require.Equal(t, byte('"'), entry.ETag[0])
|
||||
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
|
||||
}
|
||||
|
||||
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
|
||||
// channels are not JSON-serializable
|
||||
etag := buildETagFromAny(make(chan int))
|
||||
require.Empty(t, etag)
|
||||
}
|
||||
|
||||
func TestParseBoolQueryWithDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
def bool
|
||||
want bool
|
||||
}{
|
||||
{"empty returns default true", "", true, true},
|
||||
{"empty returns default false", "", false, false},
|
||||
{"1", "1", false, true},
|
||||
{"true", "true", false, true},
|
||||
{"TRUE", "TRUE", false, true},
|
||||
{"yes", "yes", false, true},
|
||||
{"on", "on", false, true},
|
||||
{"0", "0", true, false},
|
||||
{"false", "false", true, false},
|
||||
{"FALSE", "FALSE", true, false},
|
||||
{"no", "no", true, false},
|
||||
{"off", "off", true, false},
|
||||
{"whitespace trimmed", " true ", false, true},
|
||||
{"unknown returns default true", "maybe", true, true},
|
||||
{"unknown returns default false", "maybe", false, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := parseBoolQueryWithDefault(tc.raw, tc.def)
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
SubscriptionID int64 `json:"subscription_id"`
|
||||
Body AdjustSubscriptionRequest `json:"body"`
|
||||
}{
|
||||
SubscriptionID: subscriptionID,
|
||||
Body: req,
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,12 +18,14 @@ import (
|
||||
// SystemHandler handles system-related operations
|
||||
type SystemHandler struct {
|
||||
updateSvc *service.UpdateService
|
||||
lockSvc *service.SystemOperationLockService
|
||||
}
|
||||
|
||||
// NewSystemHandler creates a new SystemHandler
|
||||
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
|
||||
return &SystemHandler{
|
||||
updateSvc: updateSvc,
|
||||
lockSvc: lockSvc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
|
||||
// PerformUpdate downloads and applies the update
|
||||
// POST /api/v1/admin/system/update
|
||||
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
|
||||
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
operationID := buildSystemOperationID(c, "update")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var releaseReason string
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release(releaseReason, succeeded)
|
||||
}()
|
||||
|
||||
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
|
||||
releaseReason = "SYSTEM_UPDATE_FAILED"
|
||||
return nil, err
|
||||
}
|
||||
succeeded = true
|
||||
|
||||
return gin.H{
|
||||
"message": "Update completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// Rollback restores the previous version
|
||||
// POST /api/v1/admin/system/rollback
|
||||
func (h *SystemHandler) Rollback(c *gin.Context) {
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
response.Error(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
operationID := buildSystemOperationID(c, "rollback")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var releaseReason string
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release(releaseReason, succeeded)
|
||||
}()
|
||||
|
||||
if err := h.updateSvc.Rollback(); err != nil {
|
||||
releaseReason = "SYSTEM_ROLLBACK_FAILED"
|
||||
return nil, err
|
||||
}
|
||||
succeeded = true
|
||||
|
||||
return gin.H{
|
||||
"message": "Rollback completed. Please restart the service.",
|
||||
"need_restart": true,
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// RestartService restarts the systemd service
|
||||
// POST /api/v1/admin/system/restart
|
||||
func (h *SystemHandler) RestartService(c *gin.Context) {
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
operationID := buildSystemOperationID(c, "restart")
|
||||
payload := gin.H{"operation_id": operationID}
|
||||
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
lock, release, err := h.acquireSystemLock(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
succeeded := false
|
||||
defer func() {
|
||||
release("", succeeded)
|
||||
}()
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Service restart initiated",
|
||||
// Schedule service restart in background after sending response
|
||||
// This ensures the client receives the success response before the service restarts
|
||||
go func() {
|
||||
// Wait a moment to ensure the response is sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
sysutil.RestartServiceAsync()
|
||||
}()
|
||||
succeeded = true
|
||||
return gin.H{
|
||||
"message": "Service restart initiated",
|
||||
"operation_id": lock.OperationID(),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SystemHandler) acquireSystemLock(
|
||||
ctx context.Context,
|
||||
operationID string,
|
||||
) (*service.SystemOperationLock, func(string, bool), error) {
|
||||
if h.lockSvc == nil {
|
||||
return nil, nil, service.ErrIdempotencyStoreUnavail
|
||||
}
|
||||
lock, err := h.lockSvc.Acquire(ctx, operationID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
release := func(reason string, succeeded bool) {
|
||||
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
|
||||
}
|
||||
return lock, release, nil
|
||||
}
|
||||
|
||||
func buildSystemOperationID(c *gin.Context, operation string) string {
|
||||
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
|
||||
if key == "" {
|
||||
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
|
||||
}
|
||||
actorScope := "admin:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
|
||||
hash := service.HashIdempotencyKey(seed)
|
||||
if len(hash) > 24 {
|
||||
hash = hash[:24]
|
||||
}
|
||||
return "sysop-" + hash
|
||||
}
|
||||
|
||||
@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"request_type": "invalid",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"request_type": "ws_v2",
|
||||
"stream": false,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.NotNil(t, created.Filters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
|
||||
require.Nil(t, created.Filters.Stream)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"stream": true,
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.Nil(t, created.Filters.RequestType)
|
||||
require.NotNil(t, created.Filters.Stream)
|
||||
require.True(t, *created.Filters.Stream)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
@@ -50,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Model *string `json:"model"`
|
||||
RequestType *string `json:"request_type"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
Timezone string `json:"timezone"`
|
||||
@@ -59,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct {
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
exactTotal := false
|
||||
if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" {
|
||||
parsed, err := strconv.ParseBool(exactTotalRaw)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid exact_total value, use true or false")
|
||||
return
|
||||
}
|
||||
exactTotal = parsed
|
||||
}
|
||||
|
||||
// Parse filters
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
@@ -100,8 +111,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -151,10 +171,12 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
ExactTotal: exactTotal,
|
||||
}
|
||||
|
||||
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
|
||||
@@ -213,8 +235,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -277,6 +308,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: &startTime,
|
||||
@@ -378,11 +410,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
operator = subject.UserID
|
||||
}
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
@@ -390,7 +422,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
for i := range tasks {
|
||||
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
||||
}
|
||||
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -431,6 +463,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
|
||||
var requestType *int16
|
||||
stream := req.Stream
|
||||
if req.RequestType != nil {
|
||||
parsed, err := service.ParseUsageRequestType(*req.RequestType)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
stream = nil
|
||||
}
|
||||
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
@@ -439,7 +484,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
AccountID: req.AccountID,
|
||||
GroupID: req.GroupID,
|
||||
Model: req.Model,
|
||||
Stream: req.Stream,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: req.BillingType,
|
||||
}
|
||||
|
||||
@@ -463,38 +509,50 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
if filters.Model != nil {
|
||||
model = *filters.Model
|
||||
}
|
||||
var stream any
|
||||
var streamValue any
|
||||
if filters.Stream != nil {
|
||||
stream = *filters.Stream
|
||||
streamValue = *filters.Stream
|
||||
}
|
||||
var requestTypeName any
|
||||
if filters.RequestType != nil {
|
||||
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
|
||||
}
|
||||
var billingType any
|
||||
if filters.BillingType != nil {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
OperatorID int64 `json:"operator_id"`
|
||||
Body CreateUsageCleanupTaskRequest `json:"body"`
|
||||
}{
|
||||
OperatorID: subject.UserID,
|
||||
Body: req,
|
||||
}
|
||||
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
requestTypeName,
|
||||
streamValue,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
return nil, err
|
||||
}
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
return dto.UsageCleanupTaskFromService(task), nil
|
||||
})
|
||||
}
|
||||
|
||||
// CancelCleanupTask handles canceling a usage cleanup task
|
||||
@@ -515,12 +573,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
||||
response.BadRequest(c, "Invalid task id")
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
||||
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type adminUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
listFilters usagestats.UsageLogFilters
|
||||
statsFilters usagestats.UsageLogFilters
|
||||
}
|
||||
|
||||
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
s.listFilters = filters
|
||||
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||
Total: 0,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
Pages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
s.statsFilters = filters
|
||||
return &usagestats.UsageStats{}, nil
|
||||
}
|
||||
|
||||
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, nil, nil, nil)
|
||||
router := gin.New()
|
||||
router.GET("/admin/usage", handler.List)
|
||||
router.GET("/admin/usage/stats", handler.Stats)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAdminUsageListRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.listFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||
require.Nil(t, repo.listFilters.Stream)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidRequestType(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidStream(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageListExactTotalTrue(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.True(t, repo.listFilters.ExactTotal)
|
||||
}
|
||||
|
||||
func TestAdminUsageListInvalidExactTotal(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.NotNil(t, repo.statsFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
|
||||
require.Nil(t, repo.statsFilters.Stream)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAdminUsageStatsInvalidStream(t *testing.T) {
|
||||
repo := &adminUsageRepoCapture{}
|
||||
router := newAdminUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct {
|
||||
Attributes map[int64]map[int64]string `json:"attributes"`
|
||||
}
|
||||
|
||||
var userAttributesBatchCache = newSnapshotCache(30 * time.Second)
|
||||
|
||||
// AttributeDefinitionResponse represents attribute definition response
|
||||
type AttributeDefinitionResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
@@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UserIDs) == 0 {
|
||||
userIDs := normalizeInt64IDList(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}})
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs)
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
if cached, ok := userAttributesBatchCache.Get(cacheKey); ok {
|
||||
c.Header("X-Snapshot-Cache", "hit")
|
||||
response.Success(c, cached.Payload)
|
||||
return
|
||||
}
|
||||
|
||||
attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, BatchUserAttributesResponse{Attributes: attrs})
|
||||
payload := BatchUserAttributesResponse{Attributes: attrs}
|
||||
userAttributesBatchCache.Set(cacheKey, payload)
|
||||
c.Header("X-Snapshot-Cache", "miss")
|
||||
response.Success(c, payload)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -33,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
|
||||
|
||||
// CreateUserRequest represents admin create user request
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
AllowedGroups []int64 `json:"allowed_groups"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest represents admin update user request
|
||||
@@ -55,7 +57,8 @@ type UpdateUserRequest struct {
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@@ -78,8 +81,8 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
if runes := []rune(search); len(runes) > 100 {
|
||||
search = string(runes[:100])
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
@@ -88,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
Search: search,
|
||||
Attributes: parseAttributeFilters(c),
|
||||
}
|
||||
if raw, ok := c.GetQuery("include_subscriptions"); ok {
|
||||
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
|
||||
filters.IncludeSubscriptions = &includeSubscriptions
|
||||
}
|
||||
|
||||
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
|
||||
if err != nil {
|
||||
@@ -173,13 +180,14 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -206,15 +214,16 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
|
||||
// 使用指针类型直接传递,nil 表示未提供该字段
|
||||
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
Username: req.Username,
|
||||
Notes: req.Notes,
|
||||
Balance: req.Balance,
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -257,13 +266,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
idempotencyPayload := struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Body UpdateBalanceRequest `json:"body"`
|
||||
}{
|
||||
UserID: userID,
|
||||
Body: req,
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
return dto.UserFromServiceAdmin(user), nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -35,6 +37,11 @@ type CreateAPIKeyRequest struct {
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
Quota *float64 `json:"quota"` // 配额限制 (USD)
|
||||
ExpiresInDays *int `json:"expires_in_days"` // 过期天数
|
||||
|
||||
// Rate limit fields (0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest represents the update API key request payload
|
||||
@@ -47,6 +54,12 @@ type UpdateAPIKeyRequest struct {
|
||||
Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制
|
||||
ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601)
|
||||
ResetQuota *bool `json:"reset_quota"` // 重置已用配额
|
||||
|
||||
// Rate limit fields (nil = no change, 0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量
|
||||
}
|
||||
|
||||
// List handles listing user's API keys with pagination
|
||||
@@ -61,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
|
||||
// Parse filter parameters
|
||||
var filters service.APIKeyListFilters
|
||||
if search := strings.TrimSpace(c.Query("search")); search != "" {
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
filters.Search = search
|
||||
}
|
||||
filters.Status = c.Query("status")
|
||||
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
|
||||
gid, err := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filters.GroupID = &gid
|
||||
}
|
||||
}
|
||||
|
||||
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -130,13 +159,23 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
if req.Quota != nil {
|
||||
svcReq.Quota = *req.Quota
|
||||
}
|
||||
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
if req.RateLimit5h != nil {
|
||||
svcReq.RateLimit5h = *req.RateLimit5h
|
||||
}
|
||||
if req.RateLimit1d != nil {
|
||||
svcReq.RateLimit1d = *req.RateLimit1d
|
||||
}
|
||||
if req.RateLimit7d != nil {
|
||||
svcReq.RateLimit7d = *req.RateLimit7d
|
||||
}
|
||||
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dto.APIKeyFromService(key), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
@@ -161,10 +200,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
svcReq := service.UpdateAPIKeyRequest{
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
ResetQuota: req.ResetQuota,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
ResetQuota: req.ResetQuota,
|
||||
RateLimit5h: req.RateLimit5h,
|
||||
RateLimit1d: req.RateLimit1d,
|
||||
RateLimit7d: req.RateLimit7d,
|
||||
ResetRateLimitUsage: req.ResetRateLimitUsage,
|
||||
}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -112,12 +113,10 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
|
||||
if req.VerifyCode == "" {
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
// Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token)
|
||||
if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
_, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
|
||||
@@ -448,17 +447,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL)
|
||||
if frontendBaseURL == "" {
|
||||
slog.Error("server.frontend_url not configured; cannot build password reset link")
|
||||
response.InternalError(c, "Password reset is not configured")
|
||||
return
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) {
|
||||
lastUsed := time.Now().UTC().Truncate(time.Second)
|
||||
src := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 2,
|
||||
Key: "sk-map-last-used",
|
||||
Name: "Mapper",
|
||||
Status: service.StatusActive,
|
||||
LastUsedAt: &lastUsed,
|
||||
}
|
||||
|
||||
out := APIKeyFromService(src)
|
||||
require.NotNil(t, out)
|
||||
require.NotNil(t, out.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) {
|
||||
src := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 2,
|
||||
Key: "sk-map-last-used-nil",
|
||||
Name: "MapperNil",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
out := APIKeyFromService(src)
|
||||
require.NotNil(t, out)
|
||||
require.Nil(t, out.LastUsedAt)
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -58,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,21 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
return nil
|
||||
}
|
||||
return &APIKey{
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
Quota: k.Quota,
|
||||
QuotaUsed: k.QuotaUsed,
|
||||
ExpiresAt: k.ExpiresAt,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
ID: k.ID,
|
||||
UserID: k.UserID,
|
||||
Key: k.Key,
|
||||
Name: k.Name,
|
||||
GroupID: k.GroupID,
|
||||
Status: k.Status,
|
||||
IPWhitelist: k.IPWhitelist,
|
||||
IPBlacklist: k.IPBlacklist,
|
||||
LastUsedAt: k.LastUsedAt,
|
||||
Quota: k.Quota,
|
||||
QuotaUsed: k.QuotaUsed,
|
||||
ExpiresAt: k.ExpiresAt,
|
||||
CreatedAt: k.CreatedAt,
|
||||
UpdatedAt: k.UpdatedAt,
|
||||
RateLimit5h: k.RateLimit5h,
|
||||
RateLimit1d: k.RateLimit1d,
|
||||
RateLimit7d: k.RateLimit7d,
|
||||
Usage5h: k.Usage5h,
|
||||
Usage1d: k.Usage1d,
|
||||
Usage7d: k.Usage7d,
|
||||
Window5hStart: k.Window5hStart,
|
||||
Window1dStart: k.Window1dStart,
|
||||
Window7dStart: k.Window7dStart,
|
||||
User: UserFromServiceShallow(k.User),
|
||||
Group: GroupFromServiceShallow(k.Group),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,24 +142,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
|
||||
func groupFromServiceBase(g *service.Group) Group {
|
||||
return Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
// 无效请求兜底分组
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
@@ -201,6 +218,17 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
|
||||
out.SessionIdleTimeoutMin = &idleTimeout
|
||||
}
|
||||
if rpm := a.GetBaseRPM(); rpm > 0 {
|
||||
out.BaseRPM = &rpm
|
||||
strategy := a.GetRPMStrategy()
|
||||
out.RPMStrategy = &strategy
|
||||
buffer := a.GetRPMStickyBuffer()
|
||||
out.RPMStickyBuffer = &buffer
|
||||
}
|
||||
// 用户消息队列模式
|
||||
if mode := a.GetUserMsgQueueMode(); mode != "" {
|
||||
out.UserMsgQueueMode = &mode
|
||||
}
|
||||
// TLS指纹伪装开关
|
||||
if a.IsTLSFingerprintEnabled() {
|
||||
enabled := true
|
||||
@@ -211,6 +239,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
// 缓存 TTL 强制替换
|
||||
if a.IsCacheTTLOverrideEnabled() {
|
||||
enabled := true
|
||||
out.CacheTTLOverrideEnabled = &enabled
|
||||
target := a.GetCacheTTLOverrideTarget()
|
||||
out.CacheTTLOverrideTarget = &target
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -271,7 +306,6 @@ func ProxyFromService(p *service.Proxy) *Proxy {
|
||||
Host: p.Host,
|
||||
Port: p.Port,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
Status: p.Status,
|
||||
CreatedAt: p.CreatedAt,
|
||||
UpdatedAt: p.UpdatedAt,
|
||||
@@ -293,6 +327,56 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
CountryCode: p.CountryCode,
|
||||
Region: p.Region,
|
||||
City: p.City,
|
||||
QualityStatus: p.QualityStatus,
|
||||
QualityScore: p.QualityScore,
|
||||
QualityGrade: p.QualityGrade,
|
||||
QualitySummary: p.QualitySummary,
|
||||
QualityChecked: p.QualityChecked,
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users.
|
||||
// It includes the password field - user-facing endpoints must not use this.
|
||||
func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
base := ProxyFromService(p)
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminProxy{
|
||||
Proxy: *base,
|
||||
Password: p.Password,
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO.
|
||||
// It includes the password field - user-facing endpoints must not use this.
|
||||
func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
admin := ProxyFromServiceAdmin(&p.Proxy)
|
||||
if admin == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminProxyWithAccountCount{
|
||||
AdminProxy: *admin,
|
||||
AccountCount: p.AccountCount,
|
||||
LatencyMs: p.LatencyMs,
|
||||
LatencyStatus: p.LatencyStatus,
|
||||
LatencyMessage: p.LatencyMessage,
|
||||
IPAddress: p.IPAddress,
|
||||
Country: p.Country,
|
||||
CountryCode: p.CountryCode,
|
||||
Region: p.Region,
|
||||
City: p.City,
|
||||
QualityStatus: p.QualityStatus,
|
||||
QualityScore: p.QualityScore,
|
||||
QualityGrade: p.QualityGrade,
|
||||
QualitySummary: p.QualitySummary,
|
||||
QualityChecked: p.QualityChecked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,6 +452,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
requestType := l.EffectiveRequestType()
|
||||
stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
@@ -392,12 +478,16 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
RequestType: requestType.String(),
|
||||
Stream: stream,
|
||||
OpenAIWSMode: openAIWSMode,
|
||||
DurationMs: l.DurationMs,
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
ImageSize: l.ImageSize,
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
@@ -445,6 +535,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
||||
AccountID: task.Filters.AccountID,
|
||||
GroupID: task.Filters.GroupID,
|
||||
Model: task.Filters.Model,
|
||||
RequestType: requestTypeStringPtr(task.Filters.RequestType),
|
||||
Stream: task.Filters.Stream,
|
||||
BillingType: task.Filters.BillingType,
|
||||
},
|
||||
@@ -460,6 +551,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
|
||||
}
|
||||
}
|
||||
|
||||
func requestTypeStringPtr(requestType *int16) *string {
|
||||
if requestType == nil {
|
||||
return nil
|
||||
}
|
||||
value := service.RequestTypeFromInt16(*requestType).String()
|
||||
return &value
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
if s == nil {
|
||||
return nil
|
||||
@@ -524,11 +623,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
statuses := make(map[string]string, len(r.Statuses))
|
||||
for userID, status := range r.Statuses {
|
||||
statuses[strconv.FormatInt(userID, 10)] = status
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
CreatedCount: r.CreatedCount,
|
||||
ReusedCount: r.ReusedCount,
|
||||
FailedCount: r.FailedCount,
|
||||
Subscriptions: subs,
|
||||
Errors: r.Errors,
|
||||
Statuses: statuses,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
73
backend/internal/handler/dto/mappers_usage_test.go
Normal file
73
backend/internal/handler/dto/mappers_usage_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
wsLog := &service.UsageLog{
|
||||
RequestID: "req_1",
|
||||
Model: "gpt-5.3-codex",
|
||||
OpenAIWSMode: true,
|
||||
}
|
||||
httpLog := &service.UsageLog{
|
||||
RequestID: "resp_1",
|
||||
Model: "gpt-5.3-codex",
|
||||
OpenAIWSMode: false,
|
||||
}
|
||||
|
||||
require.True(t, UsageLogFromService(wsLog).OpenAIWSMode)
|
||||
require.False(t, UsageLogFromService(httpLog).OpenAIWSMode)
|
||||
require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode)
|
||||
require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_2",
|
||||
Model: "gpt-5.3-codex",
|
||||
RequestType: service.RequestTypeWSV2,
|
||||
Stream: false,
|
||||
OpenAIWSMode: false,
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
require.Equal(t, "ws_v2", userDTO.RequestType)
|
||||
require.True(t, userDTO.Stream)
|
||||
require.True(t, userDTO.OpenAIWSMode)
|
||||
require.Equal(t, "ws_v2", adminDTO.RequestType)
|
||||
require.True(t, adminDTO.Stream)
|
||||
require.True(t, adminDTO.OpenAIWSMode)
|
||||
}
|
||||
|
||||
func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requestType := int16(service.RequestTypeStream)
|
||||
task := &service.UsageCleanupTask{
|
||||
ID: 1,
|
||||
Status: service.UsageCleanupStatusPending,
|
||||
Filters: service.UsageCleanupFilters{
|
||||
RequestType: &requestType,
|
||||
},
|
||||
}
|
||||
|
||||
dtoTask := UsageCleanupTaskFromService(task)
|
||||
require.NotNil(t, dtoTask)
|
||||
require.NotNil(t, dtoTask.Filters.RequestType)
|
||||
require.Equal(t, "stream", *dtoTask.Filters.RequestType)
|
||||
}
|
||||
|
||||
func TestRequestTypeStringPtrNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Nil(t, requestTypeStringPtr(nil))
|
||||
}
|
||||
@@ -1,14 +1,30 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CustomMenuItem represents a user-configured custom menu entry.
|
||||
type CustomMenuItem struct {
|
||||
ID string `json:"id"`
|
||||
Label string `json:"label"`
|
||||
IconSVG string `json:"icon_svg"`
|
||||
URL string `json:"url"`
|
||||
Visibility string `json:"visibility"` // "user" or "admin"
|
||||
SortOrder int `json:"sort_order"`
|
||||
}
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -27,19 +43,22 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -57,29 +76,80 @@ type SystemSettings struct {
|
||||
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
|
||||
OpsQueryModeDefault string `json:"ops_query_mode_default"`
|
||||
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
|
||||
|
||||
MinClaudeCodeVersion string `json:"min_claude_code_version"`
|
||||
|
||||
// 分组隔离
|
||||
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
|
||||
}
|
||||
|
||||
type DefaultSubscriptionSetting struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
InvitationCodeEnabled bool `json:"invitation_code_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
|
||||
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
}
|
||||
|
||||
// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
|
||||
type SoraS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
CDNURL string `json:"cdn_url"`
|
||||
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ListSoraS3ProfilesResponse Sora S3 配置列表响应
|
||||
type ListSoraS3ProfilesResponse struct {
|
||||
ActiveProfileID string `json:"active_profile_id"`
|
||||
Items []SoraS3Profile `json:"items"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
@@ -90,3 +160,29 @@ type StreamTimeoutSettings struct {
|
||||
ThresholdCount int `json:"threshold_count"`
|
||||
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
|
||||
}
|
||||
|
||||
// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem.
|
||||
// Returns empty slice on empty/invalid input.
|
||||
func ParseCustomMenuItems(raw string) []CustomMenuItem {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || raw == "[]" {
|
||||
return []CustomMenuItem{}
|
||||
}
|
||||
var items []CustomMenuItem
|
||||
if err := json.Unmarshal([]byte(raw), &items); err != nil {
|
||||
return []CustomMenuItem{}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries.
|
||||
func ParseUserVisibleMenuItems(raw string) []CustomMenuItem {
|
||||
items := ParseCustomMenuItems(raw)
|
||||
filtered := make([]CustomMenuItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.Visibility != "admin" {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
@@ -26,7 +26,9 @@ type AdminUser struct {
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
@@ -38,12 +40,24 @@ type APIKey struct {
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"`
|
||||
IPBlacklist []string `json:"ip_blacklist"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// Rate limit fields
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5hStart *time.Time `json:"window_5h_start"`
|
||||
Window1dStart *time.Time `json:"window_1d_start"`
|
||||
Window7dStart *time.Time `json:"window_7d_start"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
@@ -67,12 +81,21 @@ type Group struct {
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Sora 按次计费配置
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
// Sora 存储配额
|
||||
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -141,6 +164,13 @@ type Account struct {
|
||||
MaxSessions *int `json:"max_sessions,omitempty"`
|
||||
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
|
||||
|
||||
// RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
BaseRPM *int `json:"base_rpm,omitempty"`
|
||||
RPMStrategy *string `json:"rpm_strategy,omitempty"`
|
||||
RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
|
||||
UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"`
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
@@ -150,6 +180,11 @@ type Account struct {
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -191,6 +226,37 @@ type ProxyWithAccountCount struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
QualityStatus string `json:"quality_status,omitempty"`
|
||||
QualityScore *int `json:"quality_score,omitempty"`
|
||||
QualityGrade string `json:"quality_grade,omitempty"`
|
||||
QualitySummary string `json:"quality_summary,omitempty"`
|
||||
QualityChecked *int64 `json:"quality_checked,omitempty"`
|
||||
}
|
||||
|
||||
// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。
|
||||
// 注意:普通接口不得使用此 DTO。
|
||||
type AdminProxy struct {
|
||||
Proxy
|
||||
Password string `json:"password,omitempty"`
|
||||
}
|
||||
|
||||
// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。
|
||||
type AdminProxyWithAccountCount struct {
|
||||
AdminProxy
|
||||
AccountCount int64 `json:"account_count"`
|
||||
LatencyMs *int64 `json:"latency_ms,omitempty"`
|
||||
LatencyStatus string `json:"latency_status,omitempty"`
|
||||
LatencyMessage string `json:"latency_message,omitempty"`
|
||||
IPAddress string `json:"ip_address,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
QualityStatus string `json:"quality_status,omitempty"`
|
||||
QualityScore *int `json:"quality_score,omitempty"`
|
||||
QualityGrade string `json:"quality_grade,omitempty"`
|
||||
QualitySummary string `json:"quality_summary,omitempty"`
|
||||
QualityChecked *int64 `json:"quality_checked,omitempty"`
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
@@ -261,18 +327,24 @@ type UsageLog struct {
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
BillingType int8 `json:"billing_type"`
|
||||
RequestType string `json:"request_type"`
|
||||
Stream bool `json:"stream"`
|
||||
OpenAIWSMode bool `json:"openai_ws_mode"`
|
||||
DurationMs *int `json:"duration_ms"`
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
MediaType *string `json:"media_type"`
|
||||
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// Cache TTL Override 标记
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
@@ -303,6 +375,7 @@ type UsageCleanupFilters struct {
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
RequestType *string `json:"request_type,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
@@ -374,9 +447,12 @@ type AdminUserSubscription struct {
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
CreatedCount int `json:"created_count"`
|
||||
ReusedCount int `json:"reused_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
Statuses map[string]string `json:"statuses,omitempty"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
174
backend/internal/handler/failover_loop.go
Normal file
174
backend/internal/handler/failover_loop.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
|
||||
// GatewayService 隐式实现此接口。
|
||||
type TempUnscheduler interface {
|
||||
TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError)
|
||||
}
|
||||
|
||||
// FailoverAction 表示 failover 错误处理后的下一步动作
|
||||
type FailoverAction int
|
||||
|
||||
const (
|
||||
// FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue)
|
||||
FailoverContinue FailoverAction = iota
|
||||
// FailoverExhausted 切换次数耗尽(调用方应返回错误响应)
|
||||
FailoverExhausted
|
||||
// FailoverCanceled context 已取消(调用方应直接 return)
|
||||
FailoverCanceled
|
||||
)
|
||||
|
||||
const (
|
||||
// maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误)
|
||||
maxSameAccountRetries = 2
|
||||
// sameAccountRetryDelay 同账号重试间隔
|
||||
sameAccountRetryDelay = 500 * time.Millisecond
|
||||
// singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。
|
||||
// Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s),
|
||||
// Handler 层只需短暂间隔后重新进入 Service 层即可。
|
||||
singleAccountBackoffDelay = 2 * time.Second
|
||||
)
|
||||
|
||||
// FailoverState 跨循环迭代共享的 failover 状态
|
||||
type FailoverState struct {
|
||||
SwitchCount int
|
||||
MaxSwitches int
|
||||
FailedAccountIDs map[int64]struct{}
|
||||
SameAccountRetryCount map[int64]int
|
||||
LastFailoverErr *service.UpstreamFailoverError
|
||||
ForceCacheBilling bool
|
||||
hasBoundSession bool
|
||||
}
|
||||
|
||||
// NewFailoverState 创建 failover 状态
|
||||
func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState {
|
||||
return &FailoverState{
|
||||
MaxSwitches: maxSwitches,
|
||||
FailedAccountIDs: make(map[int64]struct{}),
|
||||
SameAccountRetryCount: make(map[int64]int),
|
||||
hasBoundSession: hasBoundSession,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。
|
||||
// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。
|
||||
func (s *FailoverState) HandleFailoverError(
|
||||
ctx context.Context,
|
||||
gatewayService TempUnscheduler,
|
||||
accountID int64,
|
||||
platform string,
|
||||
failoverErr *service.UpstreamFailoverError,
|
||||
) FailoverAction {
|
||||
s.LastFailoverErr = failoverErr
|
||||
|
||||
// 缓存计费判断
|
||||
if needForceCacheBilling(s.hasBoundSession, failoverErr) {
|
||||
s.ForceCacheBilling = true
|
||||
}
|
||||
|
||||
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
|
||||
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
|
||||
s.SameAccountRetryCount[accountID]++
|
||||
logger.FromContext(ctx).Warn("gateway.failover_same_account_retry",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]),
|
||||
zap.Int("same_account_retry_max", maxSameAccountRetries),
|
||||
)
|
||||
if !sleepWithContext(ctx, sameAccountRetryDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
// 同账号重试用尽,执行临时封禁
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr)
|
||||
}
|
||||
|
||||
// 加入失败列表
|
||||
s.FailedAccountIDs[accountID] = struct{}{}
|
||||
|
||||
// 检查是否耗尽
|
||||
if s.SwitchCount >= s.MaxSwitches {
|
||||
return FailoverExhausted
|
||||
}
|
||||
|
||||
// 递增切换计数
|
||||
s.SwitchCount++
|
||||
logger.FromContext(ctx).Warn("gateway.failover_switch_account",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", s.SwitchCount),
|
||||
zap.Int("max_switches", s.MaxSwitches),
|
||||
)
|
||||
|
||||
// Antigravity 平台换号线性递增延时
|
||||
if platform == service.PlatformAntigravity {
|
||||
delay := time.Duration(s.SwitchCount-1) * time.Second
|
||||
if !sleepWithContext(ctx, delay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
}
|
||||
|
||||
return FailoverContinue
|
||||
}
|
||||
|
||||
// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。
|
||||
// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景:
|
||||
// 清除排除列表、等待退避后重新选号。
|
||||
//
|
||||
// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。
|
||||
// 返回 FailoverExhausted 时,调用方应返回错误响应。
|
||||
// 返回 FailoverCanceled 时,调用方应直接 return。
|
||||
func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction {
|
||||
if s.LastFailoverErr != nil &&
|
||||
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
|
||||
s.SwitchCount <= s.MaxSwitches {
|
||||
|
||||
logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff",
|
||||
zap.Duration("backoff_delay", singleAccountBackoffDelay),
|
||||
zap.Int("switch_count", s.SwitchCount),
|
||||
zap.Int("max_switches", s.MaxSwitches),
|
||||
)
|
||||
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
|
||||
return FailoverCanceled
|
||||
}
|
||||
logger.FromContext(ctx).Warn("gateway.failover_single_account_retry",
|
||||
zap.Int("switch_count", s.SwitchCount),
|
||||
zap.Int("max_switches", s.MaxSwitches),
|
||||
)
|
||||
s.FailedAccountIDs = make(map[int64]struct{})
|
||||
return FailoverContinue
|
||||
}
|
||||
return FailoverExhausted
|
||||
}
|
||||
|
||||
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。
|
||||
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。
|
||||
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
|
||||
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
|
||||
}
|
||||
|
||||
// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。
|
||||
func sleepWithContext(ctx context.Context, d time.Duration) bool {
|
||||
if d <= 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(d):
|
||||
return true
|
||||
}
|
||||
}
|
||||
732
backend/internal/handler/failover_loop_test.go
Normal file
732
backend/internal/handler/failover_loop_test.go
Normal file
@@ -0,0 +1,732 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。
|
||||
type mockTempUnscheduler struct {
|
||||
calls []tempUnscheduleCall
|
||||
}
|
||||
|
||||
type tempUnscheduleCall struct {
|
||||
accountID int64
|
||||
failoverErr *service.UpstreamFailoverError
|
||||
}
|
||||
|
||||
func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) {
|
||||
m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError {
|
||||
return &service.UpstreamFailoverError{
|
||||
StatusCode: statusCode,
|
||||
RetryableOnSameAccount: retryable,
|
||||
ForceCacheBilling: forceBilling,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NewFailoverState 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewFailoverState(t *testing.T) {
|
||||
t.Run("初始化字段正确", func(t *testing.T) {
|
||||
fs := NewFailoverState(5, true)
|
||||
require.Equal(t, 5, fs.MaxSwitches)
|
||||
require.Equal(t, 0, fs.SwitchCount)
|
||||
require.NotNil(t, fs.FailedAccountIDs)
|
||||
require.Empty(t, fs.FailedAccountIDs)
|
||||
require.NotNil(t, fs.SameAccountRetryCount)
|
||||
require.Empty(t, fs.SameAccountRetryCount)
|
||||
require.Nil(t, fs.LastFailoverErr)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
require.True(t, fs.hasBoundSession)
|
||||
})
|
||||
|
||||
t.Run("无绑定会话", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
require.Equal(t, 3, fs.MaxSwitches)
|
||||
require.False(t, fs.hasBoundSession)
|
||||
})
|
||||
|
||||
t.Run("零最大切换次数", func(t *testing.T) {
|
||||
fs := NewFailoverState(0, false)
|
||||
require.Equal(t, 0, fs.MaxSwitches)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepWithContext 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepWithContext(t *testing.T) {
|
||||
t.Run("零时长立即返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), 0)
|
||||
require.True(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("负时长立即返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), -1*time.Second)
|
||||
require.True(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("正常等待后返回true", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(context.Background(), 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok)
|
||||
require.GreaterOrEqual(t, elapsed, 40*time.Millisecond)
|
||||
require.Less(t, elapsed, 500*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("已取消context立即返回false", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(ctx, 5*time.Second)
|
||||
require.False(t, ok)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("等待期间context取消返回false", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepWithContext(ctx, 5*time.Second)
|
||||
elapsed := time.Since(start)
|
||||
require.False(t, ok)
|
||||
require.Less(t, elapsed, 500*time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 基本切换流程
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_BasicSwitch(t *testing.T) {
|
||||
t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
require.Equal(t, err, fs.LastFailoverErr)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
require.Empty(t, mock.calls, "不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) {
|
||||
// switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0")
|
||||
})
|
||||
|
||||
t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) {
|
||||
// switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1 // 模拟已切换一次
|
||||
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s")
|
||||
require.Less(t, elapsed, 3*time.Second)
|
||||
})
|
||||
|
||||
t.Run("连续切换直到耗尽", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(2, false)
|
||||
|
||||
// 第一次切换:0→1
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
|
||||
// 第二次切换:1→2
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
|
||||
// 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2)
|
||||
err3 := newTestFailoverErr(503, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增")
|
||||
|
||||
// 验证失败账号列表
|
||||
require.Len(t, fs.FailedAccountIDs, 3)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(300))
|
||||
|
||||
// LastFailoverErr 应为最后一次的错误
|
||||
require.Equal(t, err3, fs.LastFailoverErr)
|
||||
})
|
||||
|
||||
t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(0, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Equal(t, 0, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 缓存计费 (ForceCacheBilling)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_CacheBilling(t *testing.T) {
|
||||
t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("两者均为false时不设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
})
|
||||
|
||||
t.Run("一旦设置不会被后续错误重置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
// 第一次:ForceCacheBilling=true → 设置
|
||||
err1 := newTestFailoverErr(500, false, true)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
|
||||
// 第二次:ForceCacheBilling=false → 仍然保持 true
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 同账号重试 (RetryableOnSameAccount)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_SameAccountRetry(t *testing.T) {
|
||||
t.Run("第一次重试返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数")
|
||||
require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表")
|
||||
require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule")
|
||||
// 验证等待了 sameAccountRetryDelay (500ms)
|
||||
require.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
|
||||
require.Less(t, elapsed, 2*time.Second)
|
||||
})
|
||||
|
||||
t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第二次
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 第一次、第二次重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, 2, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 第三次:重试已达到 maxSameAccountRetries(2),应切换账号
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
|
||||
// 验证 TempUnschedule 被调用
|
||||
require.Len(t, mock.calls, 1)
|
||||
require.Equal(t, int64(100), mock.calls[0].accountID)
|
||||
require.Equal(t, err, mock.calls[0].failoverErr)
|
||||
})
|
||||
|
||||
t.Run("不同账号独立跟踪重试次数", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 账号 100 第一次重试
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
|
||||
// 账号 200 第一次重试(独立计数)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[200])
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响")
|
||||
})
|
||||
|
||||
t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
// 耗尽账号 100 的重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
// 第三次: 重试耗尽 → 切换
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — TempUnschedule 调用验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_TempUnschedule(t *testing.T) {
|
||||
t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Empty(t, mock.calls)
|
||||
})
|
||||
|
||||
t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(502, true, false)
|
||||
|
||||
// 耗尽重试
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
fs.HandleFailoverError(context.Background(), mock, 42, "openai", err)
|
||||
|
||||
require.Len(t, mock.calls, 1)
|
||||
require.Equal(t, int64(42), mock.calls[0].accountID)
|
||||
require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode)
|
||||
require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — Context 取消
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_ContextCanceled(t *testing.T) {
|
||||
t.Run("同账号重试sleep期间context取消", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(ctx, mock, 100, "openai", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||
// 重试计数仍应递增
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[100])
|
||||
})
|
||||
|
||||
t.Run("Antigravity延迟期间context取消", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — FailedAccountIDs 跟踪
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_FailedAccountIDs(t *testing.T) {
|
||||
t.Run("切换时添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false))
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(200))
|
||||
require.Len(t, fs.FailedAccountIDs, 2)
|
||||
})
|
||||
|
||||
t.Run("耗尽时也添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(0, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Contains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
|
||||
t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(400, true, false))
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.NotContains(t, fs.FailedAccountIDs, int64(100))
|
||||
})
|
||||
|
||||
t.Run("同一账号多次切换不重复添加", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(5, false)
|
||||
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false))
|
||||
require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — LastFailoverErr 更新
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_LastFailoverErr(t *testing.T) {
|
||||
t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.Equal(t, err1, fs.LastFailoverErr)
|
||||
|
||||
err2 := newTestFailoverErr(502, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.Equal(t, err2, fs.LastFailoverErr)
|
||||
})
|
||||
|
||||
t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
|
||||
err := newTestFailoverErr(400, true, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, err, fs.LastFailoverErr)
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 综合集成场景
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_IntegrationScenario(t *testing.T) {
|
||||
t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, true) // hasBoundSession=true
|
||||
|
||||
// 1. 账号 100 遇到可重试错误,同账号重试 2 次
|
||||
retryErr := newTestFailoverErr(400, true, false)
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling")
|
||||
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
|
||||
// 2. 账号 100 重试耗尽 → TempUnschedule + 切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SwitchCount)
|
||||
require.Len(t, mock.calls, 1)
|
||||
|
||||
// 3. 账号 200 遇到不可重试错误 → 直接切换
|
||||
switchErr := newTestFailoverErr(500, false, false)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 2, fs.SwitchCount)
|
||||
|
||||
// 4. 账号 300 遇到不可重试错误 → 再切换
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 3, fs.SwitchCount)
|
||||
|
||||
// 5. 账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3)
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
|
||||
// 最终状态验证
|
||||
require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增")
|
||||
require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中")
|
||||
require.True(t, fs.ForceCacheBilling)
|
||||
require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule")
|
||||
})
|
||||
|
||||
t.Run("模拟Antigravity平台完整流程", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(2, false)
|
||||
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
// 第一次切换:delay = 0s
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err)
|
||||
elapsed := time.Since(start)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0")
|
||||
|
||||
// 第二次切换:delay = 1s
|
||||
start = time.Now()
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err)
|
||||
elapsed = time.Since(start)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s")
|
||||
|
||||
// 第三次:耗尽(无延迟,因为在检查延迟之前就返回了)
|
||||
start = time.Now()
|
||||
action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err)
|
||||
elapsed = time.Since(start)
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟")
|
||||
})
|
||||
|
||||
t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false) // hasBoundSession=false
|
||||
|
||||
// 第一次:ForceCacheBilling=false
|
||||
err1 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1)
|
||||
require.False(t, fs.ForceCacheBilling)
|
||||
|
||||
// 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换)
|
||||
err2 := newTestFailoverErr(500, false, true)
|
||||
fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2)
|
||||
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling")
|
||||
|
||||
// 第三次:ForceCacheBilling=false,但状态仍保持 true
|
||||
err3 := newTestFailoverErr(500, false, false)
|
||||
fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3)
|
||||
require.True(t, fs.ForceCacheBilling, "不应重置")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleFailoverError — 边界条件
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleFailoverError_EdgeCases(t *testing.T) {
|
||||
t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(0, false, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
})
|
||||
|
||||
t.Run("AccountID为0也能正常跟踪", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, true, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[0])
|
||||
})
|
||||
|
||||
t.Run("负AccountID也能正常跟踪", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
err := newTestFailoverErr(500, true, false)
|
||||
|
||||
action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err)
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Equal(t, 1, fs.SameAccountRetryCount[-1])
|
||||
})
|
||||
|
||||
t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) {
|
||||
mock := &mockTempUnscheduler{}
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.SwitchCount = 1
|
||||
err := newTestFailoverErr(500, false, false)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleFailoverError(context.Background(), mock, 100, "", err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟")
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HandleSelectionExhausted 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHandleSelectionExhausted(t *testing.T) {
|
||||
t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
// LastFailoverErr 为 nil
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
})
|
||||
|
||||
t.Run("非503错误返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(500, false, false)
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
})
|
||||
|
||||
t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.FailedAccountIDs[100] = struct{}{}
|
||||
fs.SwitchCount = 1
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表")
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s")
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) {
|
||||
fs := NewFailoverState(2, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.SwitchCount = 3 // > MaxSwitches(2)
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverExhausted, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "不应等待")
|
||||
})
|
||||
|
||||
t.Run("503但context已取消_返回Canceled", func(t *testing.T) {
|
||||
fs := NewFailoverState(3, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
action := fs.HandleSelectionExhausted(ctx)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.Equal(t, FailoverCanceled, action)
|
||||
require.Less(t, elapsed, 100*time.Millisecond, "应立即返回")
|
||||
})
|
||||
|
||||
t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) {
|
||||
fs := NewFailoverState(2, false)
|
||||
fs.LastFailoverErr = newTestFailoverErr(503, false, false)
|
||||
fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试
|
||||
|
||||
action := fs.HandleSelectionExhausted(context.Background())
|
||||
require.Equal(t, FailoverContinue, action)
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "error", parsed["type"])
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &GatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sleepAntigravitySingleAccountBackoff 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok, "should return true when context is not canceled")
|
||||
// 固定延迟 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s")
|
||||
require.Less(t, elapsed, 5*time.Second, "should not wait too long")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // 立即取消
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 1)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.False(t, ok, "should return false when context is canceled")
|
||||
require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel")
|
||||
}
|
||||
|
||||
func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) {
|
||||
// 验证不同 retryCount 都使用固定 2s 延迟
|
||||
ctx := context.Background()
|
||||
|
||||
start := time.Now()
|
||||
ok := sleepAntigravitySingleAccountBackoff(ctx, 5)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.True(t, ok)
|
||||
// 即使 retryCount=5,延迟仍然是固定的 2s
|
||||
require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond)
|
||||
require.Less(t, elapsed, 5*time.Second)
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”,
|
||||
// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时,
|
||||
// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。
|
||||
|
||||
type fakeSchedulerCache struct {
|
||||
accounts []*service.Account
|
||||
}
|
||||
|
||||
func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return f.accounts, true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
|
||||
func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil }
|
||||
func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil }
|
||||
func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil }
|
||||
|
||||
type fakeGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) {
|
||||
return f.group, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil }
|
||||
func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil }
|
||||
func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil }
|
||||
func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil }
|
||||
func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeConcurrencyCache struct{}
|
||||
|
||||
func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil }
|
||||
func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil }
|
||||
func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, id := range accountIDs {
|
||||
result[id] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
|
||||
|
||||
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
|
||||
t.Helper()
|
||||
|
||||
schedulerCache := &fakeSchedulerCache{accounts: accounts}
|
||||
schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil)
|
||||
|
||||
gwSvc := service.NewGatewayService(
|
||||
nil, // accountRepo (not used: scheduler snapshot hit)
|
||||
&fakeGroupRepo{group: group},
|
||||
nil, // usageLogRepo
|
||||
nil, // userRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache (disable sticky)
|
||||
nil, // cfg
|
||||
schedulerSnapshot,
|
||||
nil, // concurrencyService (disable load-aware; tryAcquire always acquired)
|
||||
nil, // billingService
|
||||
nil, // rateLimitService
|
||||
nil, // billingCacheService
|
||||
nil, // identityService
|
||||
nil, // httpUpstream
|
||||
nil, // deferredService
|
||||
nil, // claudeTokenProvider
|
||||
nil, // sessionLimitCache
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
)
|
||||
|
||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
|
||||
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
|
||||
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
|
||||
|
||||
h := &GatewayHandler{
|
||||
gatewayService: gwSvc,
|
||||
billingCacheService: billingCacheSvc,
|
||||
concurrencyHelper: concurrencyHelper,
|
||||
// 这些字段对本测试不敏感,保持较小即可
|
||||
maxAccountSwitches: 1,
|
||||
maxAccountSwitchesGemini: 1,
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
billingCacheSvc.Stop()
|
||||
}
|
||||
return h, cleanup
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2001)
|
||||
accountID := int64(1001)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group))
|
||||
c.Request = req
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3001,
|
||||
UserID: 4001,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4001,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
// 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果)
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
|
||||
content, ok := resp["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
first, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "New Conversation", first["text"])
|
||||
}
|
||||
|
||||
func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(2002)
|
||||
accountID := int64(1002)
|
||||
|
||||
group := &service.Group{
|
||||
ID: groupID,
|
||||
Hydrated: true,
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
ID: accountID,
|
||||
Name: "ag-2",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "tok_xxx",
|
||||
"intercept_warmup_requests": true,
|
||||
},
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}},
|
||||
}
|
||||
|
||||
h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account})
|
||||
defer cleanup()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
body := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}]
|
||||
}`)
|
||||
req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果:
|
||||
// - 写入 request.Context(Service读取)
|
||||
// - 写入 gin.Context(Handler快速读取)
|
||||
ctx := context.WithValue(req.Context(), ctxkey.Group, group)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity)
|
||||
req = req.WithContext(ctx)
|
||||
c.Request = req
|
||||
c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity)
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 3002,
|
||||
UserID: 4002,
|
||||
GroupID: &groupID,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 4002,
|
||||
Concurrency: 10,
|
||||
Balance: 100,
|
||||
},
|
||||
Group: group,
|
||||
}
|
||||
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10})
|
||||
|
||||
h.Messages(c)
|
||||
|
||||
require.Equal(t, 200, rec.Code)
|
||||
|
||||
selected, ok := c.Get(opsAccountIDKey)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, accountID, selected)
|
||||
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, "msg_mock_warmup", resp["id"])
|
||||
require.Equal(t, "claude-sonnet-4-5", resp["model"])
|
||||
}
|
||||
@@ -4,8 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -17,23 +18,91 @@ import (
|
||||
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
||||
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||
|
||||
const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"
|
||||
|
||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||
// 返回更新后的 context
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
// 解析请求体为 map
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
|
||||
if c == nil || c.Request == nil {
|
||||
return
|
||||
}
|
||||
if parsedReq != nil {
|
||||
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
|
||||
}
|
||||
|
||||
// 验证是否为 Claude Code 客户端
|
||||
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
ua := c.GetHeader("User-Agent")
|
||||
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
||||
if !claudeCodeValidator.ValidateUserAgent(ua) {
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
isClaudeCode := false
|
||||
if !strings.Contains(c.Request.URL.Path, "messages") {
|
||||
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
|
||||
isClaudeCode = true
|
||||
} else {
|
||||
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
||||
bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
|
||||
if bodyMap == nil {
|
||||
bodyMap = claudeCodeBodyMapFromContextCache(c)
|
||||
}
|
||||
if bodyMap == nil && len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
}
|
||||
|
||||
// 更新 request context
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||
|
||||
// 仅在确认为 Claude Code 客户端时提取版本号写入 context
|
||||
if isClaudeCode {
|
||||
if version := claudeCodeValidator.ExtractVersion(ua); version != "" {
|
||||
ctx = service.SetClaudeCodeVersion(ctx, version)
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
|
||||
if parsedReq == nil {
|
||||
return nil
|
||||
}
|
||||
bodyMap := map[string]any{
|
||||
"model": parsedReq.Model,
|
||||
}
|
||||
if parsedReq.System != nil || parsedReq.HasSystem {
|
||||
bodyMap["system"] = parsedReq.System
|
||||
}
|
||||
if parsedReq.MetadataUserID != "" {
|
||||
bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
|
||||
}
|
||||
return bodyMap
|
||||
}
|
||||
|
||||
func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
|
||||
if bodyMap, ok := cached.(map[string]any); ok {
|
||||
return bodyMap
|
||||
}
|
||||
}
|
||||
if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
|
||||
switch v := cached.(type) {
|
||||
case *service.ParsedRequest:
|
||||
return claudeCodeBodyMapFromParsedRequest(v)
|
||||
case service.ParsedRequest:
|
||||
return claudeCodeBodyMapFromParsedRequest(&v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
@@ -104,31 +173,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
|
||||
|
||||
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
|
||||
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
|
||||
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
|
||||
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
|
||||
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
|
||||
if releaseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
var once sync.Once
|
||||
quit := make(chan struct{})
|
||||
var stop func() bool
|
||||
|
||||
release := func() {
|
||||
once.Do(func() {
|
||||
if stop != nil {
|
||||
_ = stop()
|
||||
}
|
||||
releaseFunc()
|
||||
close(quit) // 通知监听 goroutine 退出
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context 取消时释放资源
|
||||
release()
|
||||
case <-quit:
|
||||
// 正常释放已完成,goroutine 退出
|
||||
return
|
||||
}
|
||||
}()
|
||||
stop = context.AfterFunc(ctx, release)
|
||||
|
||||
return release
|
||||
}
|
||||
@@ -153,6 +215,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
|
||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
|
||||
// 返回值: (releaseFunc, acquired, error)
|
||||
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !result.Acquired {
|
||||
return nil, false, nil
|
||||
}
|
||||
return result.ReleaseFunc, true, nil
|
||||
}
|
||||
|
||||
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
|
||||
// 返回值: (releaseFunc, acquired, error)
|
||||
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !result.Acquired {
|
||||
return nil, false, nil
|
||||
}
|
||||
return result.ReleaseFunc, true, nil
|
||||
}
|
||||
|
||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||
// For streaming requests, sends ping events during the wait.
|
||||
// streamStarted is updated if streaming response has begun.
|
||||
@@ -160,13 +248,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
if acquired {
|
||||
return releaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
@@ -180,13 +268,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Try to acquire immediately
|
||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
if acquired {
|
||||
return releaseFunc, nil
|
||||
}
|
||||
|
||||
// Need to wait - handle streaming ping if needed
|
||||
@@ -196,27 +284,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
||||
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
||||
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted)
|
||||
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false)
|
||||
}
|
||||
|
||||
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try immediate acquire first (avoid unnecessary wait)
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
acquireSlot := func() (*service.AcquireResult, error) {
|
||||
if slotType == "user" {
|
||||
return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
|
||||
if tryImmediate {
|
||||
result, err := acquireSlot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if ping is needed (streaming + ping format defined)
|
||||
@@ -242,7 +332,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -268,15 +357,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
|
||||
case <-timer.C:
|
||||
// Try to acquire slot
|
||||
var result *service.AcquireResult
|
||||
var err error
|
||||
|
||||
if slotType == "user" {
|
||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||
} else {
|
||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||
}
|
||||
|
||||
result, err := acquireSlot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -284,7 +365,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
if result.Acquired {
|
||||
return result.ReleaseFunc, nil
|
||||
}
|
||||
backoff = nextBackoff(backoff, rng)
|
||||
backoff = nextBackoff(backoff)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
@@ -292,26 +373,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
||||
|
||||
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
|
||||
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
|
||||
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true)
|
||||
}
|
||||
|
||||
// nextBackoff 计算下一次退避时间
|
||||
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
|
||||
// current: 当前退避时间
|
||||
// rng: 随机数生成器(可为 nil,此时不添加抖动)
|
||||
// 返回值:下一次退避时间(100ms ~ 2s 之间)
|
||||
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
|
||||
func nextBackoff(current time.Duration) time.Duration {
|
||||
// 指数退避:当前时间 * 1.5
|
||||
next := time.Duration(float64(current) * backoffMultiplier)
|
||||
if next > maxBackoff {
|
||||
next = maxBackoff
|
||||
}
|
||||
if rng == nil {
|
||||
return next
|
||||
}
|
||||
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
|
||||
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
|
||||
jitter := 0.8 + rng.Float64()*0.4
|
||||
jitter := 0.8 + rand.Float64()*0.4
|
||||
jittered := time.Duration(float64(next) * jitter)
|
||||
if jittered < initialBackoff {
|
||||
return initialBackoff
|
||||
|
||||
106
backend/internal/handler/gateway_helper_backoff_test.go
Normal file
106
backend/internal/handler/gateway_helper_backoff_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 ---
|
||||
|
||||
func TestNextBackoff_ExponentialGrowth(t *testing.T) {
|
||||
// 验证退避时间指数增长(乘数 1.5)
|
||||
// 由于有随机抖动(±20%),需要验证范围
|
||||
current := initialBackoff // 100ms
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
next := nextBackoff(current)
|
||||
|
||||
// 退避结果应在 [initialBackoff, maxBackoff] 范围内
|
||||
assert.GreaterOrEqual(t, int64(next), int64(initialBackoff),
|
||||
"第 %d 次退避不应低于初始值 %v", i, initialBackoff)
|
||||
assert.LessOrEqual(t, int64(next), int64(maxBackoff),
|
||||
"第 %d 次退避不应超过最大值 %v", i, maxBackoff)
|
||||
|
||||
// 为下一轮提供当前退避值
|
||||
current = next
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) {
|
||||
// 即使输入非常大,输出也不超过 maxBackoff
|
||||
for i := 0; i < 100; i++ {
|
||||
result := nextBackoff(10 * time.Second)
|
||||
assert.LessOrEqual(t, int64(result), int64(maxBackoff),
|
||||
"退避值不应超过 maxBackoff")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) {
|
||||
// 即使输入非常小,输出也不低于 initialBackoff
|
||||
for i := 0; i < 100; i++ {
|
||||
result := nextBackoff(1 * time.Millisecond)
|
||||
assert.GreaterOrEqual(t, int64(result), int64(initialBackoff),
|
||||
"退避值不应低于 initialBackoff")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextBackoff_HasJitter(t *testing.T) {
|
||||
// 验证多次调用会产生不同的值(随机抖动生效)
|
||||
// 使用相同的输入调用 50 次,收集结果
|
||||
results := make(map[time.Duration]bool)
|
||||
current := 500 * time.Millisecond
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
result := nextBackoff(current)
|
||||
results[result] = true
|
||||
}
|
||||
|
||||
// 50 次调用应该至少有 2 个不同的值(抖动存在)
|
||||
require.Greater(t, len(results), 1,
|
||||
"nextBackoff 应产生随机抖动,但所有 50 次调用结果相同")
|
||||
}
|
||||
|
||||
func TestNextBackoff_InitialValueGrows(t *testing.T) {
|
||||
// 验证从初始值开始,退避趋势是增长的
|
||||
current := initialBackoff
|
||||
var sum time.Duration
|
||||
|
||||
runs := 100
|
||||
for i := 0; i < runs; i++ {
|
||||
next := nextBackoff(current)
|
||||
sum += next
|
||||
current = next
|
||||
}
|
||||
|
||||
avg := sum / time.Duration(runs)
|
||||
// 平均退避时间应大于初始值(因为指数增长 + 上限)
|
||||
assert.Greater(t, int64(avg), int64(initialBackoff),
|
||||
"平均退避时间应大于初始退避值")
|
||||
}
|
||||
|
||||
func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) {
|
||||
// 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近
|
||||
current := initialBackoff
|
||||
for i := 0; i < 20; i++ {
|
||||
current = nextBackoff(current)
|
||||
}
|
||||
|
||||
// 经过 20 次迭代后,应该已经到达 maxBackoff 区间
|
||||
// 由于抖动,允许 ±20% 的范围
|
||||
lowerBound := time.Duration(float64(maxBackoff) * 0.8)
|
||||
assert.GreaterOrEqual(t, int64(current), int64(lowerBound),
|
||||
"经过多次退避后应收敛到 maxBackoff 附近")
|
||||
}
|
||||
|
||||
func BenchmarkNextBackoff(b *testing.B) {
|
||||
current := initialBackoff
|
||||
for i := 0; i < b.N; i++ {
|
||||
current = nextBackoff(current)
|
||||
if current > maxBackoff {
|
||||
current = initialBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
122
backend/internal/handler/gateway_helper_fastpath_test.go
Normal file
122
backend/internal/handler/gateway_helper_fastpath_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type concurrencyCacheMock struct {
|
||||
acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
releaseUserCalled int32
|
||||
releaseAccountCalled int32
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if m.acquireAccountSlotFn != nil {
|
||||
return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
atomic.AddInt32(&m.releaseAccountCalled, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if m.acquireUserSlotFn != nil {
|
||||
return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
atomic.AddInt32(&m.releaseUserCalled, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
|
||||
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
|
||||
|
||||
release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
|
||||
require.NoError(t, err)
|
||||
require.True(t, acquired)
|
||||
require.NotNil(t, release)
|
||||
|
||||
release()
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
|
||||
}
|
||||
|
||||
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
}
|
||||
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
|
||||
|
||||
release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
|
||||
require.NoError(t, err)
|
||||
require.False(t, acquired)
|
||||
require.Nil(t, release)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
|
||||
}
|
||||
317
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
317
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type helperConcurrencyCacheStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
accountSeq []bool
|
||||
userSeq []bool
|
||||
|
||||
accountAcquireCalls int
|
||||
userAcquireCalls int
|
||||
accountReleaseCalls int
|
||||
userReleaseCalls int
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accountAcquireCalls++
|
||||
if len(s.accountSeq) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
v := s.accountSeq[0]
|
||||
s.accountSeq = s.accountSeq[1:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.accountReleaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
out := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
out[accountID] = 0
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.userAcquireCalls++
|
||||
if len(s.userSeq) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
v := s.userSeq[0]
|
||||
s.userSeq = s.userSeq[1:]
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.userReleaseCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
out := make(map[int64]*service.UserLoadInfo, len(users))
|
||||
for _, user := range users {
|
||||
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(method, path, nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func validClaudeCodeBodyJSON() []byte {
|
||||
return []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||
}`)
|
||||
}
|
||||
|
||||
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
|
||||
SetClaudeCodeClientContext(c, nil, nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
// 缺少严格校验所需 header + body 字段
|
||||
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
|
||||
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
|
||||
t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
parsedReq := &service.ParsedRequest{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
System: []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
|
||||
}
|
||||
|
||||
// body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
|
||||
t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||
c.Request.Header.Set("X-App", "claude-code")
|
||||
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"system": []any{
|
||||
map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
|
||||
},
|
||||
"metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
|
||||
})
|
||||
|
||||
SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
|
||||
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, true},
|
||||
userSeq: []bool{false, true},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
|
||||
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
require.False(t, streamStarted)
|
||||
release()
|
||||
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
|
||||
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
|
||||
})
|
||||
|
||||
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, release)
|
||||
release()
|
||||
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
|
||||
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false, false, false},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
|
||||
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
})
|
||||
|
||||
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
|
||||
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
require.True(t, streamStarted)
|
||||
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
||||
errCache := &helperConcurrencyCacheStubWithError{
|
||||
err: errors.New("redis unavailable"),
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(errCache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true)
|
||||
require.Nil(t, release)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "redis unavailable")
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) {
|
||||
cache := &helperConcurrencyCacheStub{
|
||||
accountSeq: []bool{false},
|
||||
}
|
||||
concurrency := service.NewConcurrencyService(cache)
|
||||
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||
streamStarted := false
|
||||
|
||||
release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted)
|
||||
require.Nil(t, release)
|
||||
var cErr *ConcurrencyError
|
||||
require.ErrorAs(t, err, &cErr)
|
||||
require.True(t, cErr.IsTimeout)
|
||||
require.GreaterOrEqual(t, cache.accountAcquireCalls, 1)
|
||||
}
|
||||
|
||||
type helperConcurrencyCacheStubWithError struct {
|
||||
helperConcurrencyCacheStub
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, s.err
|
||||
}
|
||||
@@ -7,24 +7,23 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
@@ -143,6 +142,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusInternalServerError, "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.gemini_v1beta.models",
|
||||
zap.Int64("user_id", authSubject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||||
if !middleware.HasForcePlatform(c) {
|
||||
@@ -159,8 +165,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
stream := action == "streamGenerateContent"
|
||||
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -187,8 +194,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -208,6 +216,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err))
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -223,6 +232,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 2) billing eligibility check (after wait)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, _, message := billingErrorDetails(err)
|
||||
googleError(c, status, message)
|
||||
return
|
||||
@@ -252,6 +262,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
if sessionBoundAccountID > 0 {
|
||||
prefetchedGroupID := int64(0)
|
||||
if apiKey.GroupID != nil {
|
||||
prefetchedGroupID = *apiKey.GroupID
|
||||
}
|
||||
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// === Gemini 内容摘要会话 Fallback 逻辑 ===
|
||||
@@ -296,8 +314,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
matchedDigestChain = foundMatchedChain
|
||||
sessionBoundAccountID = foundAccountID
|
||||
geminiSessionUUID = foundUUID
|
||||
log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s",
|
||||
safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain))
|
||||
reqLog.Info("gemini.digest_fallback_matched",
|
||||
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)),
|
||||
zap.Int64("account_id", foundAccountID),
|
||||
zap.String("digest_chain", truncateDigestChain(geminiDigestChain)),
|
||||
)
|
||||
|
||||
// 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey
|
||||
// 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号
|
||||
@@ -321,55 +342,54 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession)
|
||||
|
||||
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
|
||||
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
|
||||
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
// Antigravity 单账号退避重试:分组内没有其他可用账号时,
|
||||
// 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。
|
||||
// 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。
|
||||
if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches {
|
||||
if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) {
|
||||
log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches)
|
||||
failedAccountIDs = make(map[int64]struct{})
|
||||
// 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
}
|
||||
action := fs.HandleSelectionExhausted(c.Request.Context())
|
||||
switch action {
|
||||
case FailoverContinue:
|
||||
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
continue
|
||||
case FailoverCanceled:
|
||||
return
|
||||
default: // FailoverExhausted
|
||||
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||
return
|
||||
}
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
reqLog.Info("gemini.sticky_session_account_switched",
|
||||
zap.Int64("from_account_id", sessionBoundAccountID),
|
||||
zap.Int64("to_account_id", account.ID),
|
||||
zap.Bool("clean_thought_signature", true),
|
||||
)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively")
|
||||
reqLog.Info("gemini.sticky_session_binding_missing",
|
||||
zap.Bool("clean_thought_signature", true),
|
||||
)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
@@ -388,9 +408,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
reqLog.Info("gemini.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
@@ -412,6 +435,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -420,7 +444,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
@@ -429,8 +453,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
requestCtx := c.Request.Context()
|
||||
if switchCount > 0 {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
if fs.SwitchCount > 0 {
|
||||
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
|
||||
@@ -443,27 +467,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if needForceCacheBilling(hasBoundSession, failoverErr) {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverErr = failoverErr
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr)
|
||||
switch failoverAction {
|
||||
case FailoverContinue:
|
||||
continue
|
||||
case FailoverExhausted:
|
||||
h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr)
|
||||
return
|
||||
case FailoverCanceled:
|
||||
return
|
||||
}
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
|
||||
return
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -482,31 +498,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account.ID,
|
||||
matchedDigestChain,
|
||||
); err != nil {
|
||||
log.Printf("[Gemini] Failed to save digest session: %v", err)
|
||||
reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 6) record usage async (Gemini 使用长上下文双倍计费)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
LongContextThreshold: 200000, // Gemini 200K 阈值
|
||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||
ForceCacheBilling: fcb,
|
||||
ForceCacheBilling: fs.ForceCacheBilling,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.gemini_v1beta.models"),
|
||||
zap.Int64("user_id", authSubject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", modelName),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("gemini.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
})
|
||||
reqLog.Debug("gemini.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", fs.SwitchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type AdminHandlers struct {
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
Announcement *admin.AnnouncementHandler
|
||||
DataManagement *admin.DataManagementHandler
|
||||
OAuth *admin.OAuthHandler
|
||||
OpenAIOAuth *admin.OpenAIOAuthHandler
|
||||
GeminiOAuth *admin.GeminiOAuthHandler
|
||||
@@ -25,6 +26,7 @@ type AdminHandlers struct {
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
APIKey *admin.AdminAPIKeyHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
@@ -39,6 +41,8 @@ type Handlers struct {
|
||||
Admin *AdminHandlers
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
SoraGateway *SoraGatewayHandler
|
||||
SoraClient *SoraClientHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
65
backend/internal/handler/idempotency_helper.go
Normal file
65
backend/internal/handler/idempotency_helper.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func executeUserIdempotentJSON(
|
||||
c *gin.Context,
|
||||
scope string,
|
||||
payload any,
|
||||
ttl time.Duration,
|
||||
execute func(context.Context) (any, error),
|
||||
) {
|
||||
coordinator := service.DefaultIdempotencyCoordinator()
|
||||
if coordinator == nil {
|
||||
data, err := execute(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, data)
|
||||
return
|
||||
}
|
||||
|
||||
actorScope := "user:0"
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
|
||||
actorScope = "user:" + strconv.FormatInt(subject.UserID, 10)
|
||||
}
|
||||
|
||||
result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
|
||||
Scope: scope,
|
||||
ActorScope: actorScope,
|
||||
Method: c.Request.Method,
|
||||
Route: c.FullPath(),
|
||||
IdempotencyKey: c.GetHeader("Idempotency-Key"),
|
||||
Payload: payload,
|
||||
RequireKey: true,
|
||||
TTL: ttl,
|
||||
}, execute)
|
||||
if err != nil {
|
||||
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
|
||||
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close")
|
||||
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope)
|
||||
}
|
||||
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(retryAfter))
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if result != nil && result.Replayed {
|
||||
c.Header("X-Idempotency-Replayed", "true")
|
||||
}
|
||||
response.Success(c, result.Data)
|
||||
}
|
||||
285
backend/internal/handler/idempotency_helper_test.go
Normal file
285
backend/internal/handler/idempotency_helper_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userStoreUnavailableRepoStub struct{}
|
||||
|
||||
func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
|
||||
return nil, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
|
||||
return false, errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
|
||||
return errors.New("store unavailable")
|
||||
}
|
||||
func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
|
||||
return 0, errors.New("store unavailable")
|
||||
}
|
||||
|
||||
type userMemoryIdempotencyRepoStub struct {
|
||||
mu sync.Mutex
|
||||
nextID int64
|
||||
data map[string]*service.IdempotencyRecord
|
||||
}
|
||||
|
||||
func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub {
|
||||
return &userMemoryIdempotencyRepoStub{
|
||||
nextID: 1,
|
||||
data: make(map[string]*service.IdempotencyRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string {
|
||||
return scope + "|" + keyHash
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := *in
|
||||
if in.LockedUntil != nil {
|
||||
v := *in.LockedUntil
|
||||
out.LockedUntil = &v
|
||||
}
|
||||
if in.ResponseBody != nil {
|
||||
v := *in.ResponseBody
|
||||
out.ResponseBody = &v
|
||||
}
|
||||
if in.ResponseStatus != nil {
|
||||
v := *in.ResponseStatus
|
||||
out.ResponseStatus = &v
|
||||
}
|
||||
if in.ErrorReason != nil {
|
||||
v := *in.ErrorReason
|
||||
out.ErrorReason = &v
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
k := r.key(record.Scope, record.IdempotencyKeyHash)
|
||||
if _, ok := r.data[k]; ok {
|
||||
return false, nil
|
||||
}
|
||||
cp := r.clone(record)
|
||||
cp.ID = r.nextID
|
||||
r.nextID++
|
||||
r.data[k] = cp
|
||||
record.ID = cp.ID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.clone(r.data[r.key(scope, keyHash)]), nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != fromStatus {
|
||||
return false, nil
|
||||
}
|
||||
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
|
||||
return false, nil
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusProcessing
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
rec.ErrorReason = nil
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
|
||||
return false, nil
|
||||
}
|
||||
rec.LockedUntil = &newLockedUntil
|
||||
rec.ExpiresAt = newExpiresAt
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusSucceeded
|
||||
rec.LockedUntil = nil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ResponseStatus = &responseStatus
|
||||
rec.ResponseBody = &responseBody
|
||||
rec.ErrorReason = nil
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, rec := range r.data {
|
||||
if rec.ID != id {
|
||||
continue
|
||||
}
|
||||
rec.Status = service.IdempotencyStatusFailedRetryable
|
||||
rec.LockedUntil = &lockedUntil
|
||||
rec.ExpiresAt = expiresAt
|
||||
rec.ErrorReason = &errorReason
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func withUserSubject(userID int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(1))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, executed)
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed int
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(2))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed++
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "k1")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
require.Equal(t, 0, executed)
|
||||
}
|
||||
|
||||
func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newUserMemoryIdempotencyRepoStub()
|
||||
cfg := service.DefaultIdempotencyConfig()
|
||||
cfg.ProcessingTimeout = 2 * time.Second
|
||||
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
|
||||
t.Cleanup(func() {
|
||||
service.SetDefaultIdempotencyCoordinator(nil)
|
||||
})
|
||||
|
||||
var executed atomic.Int32
|
||||
router := gin.New()
|
||||
router.Use(withUserSubject(3))
|
||||
router.POST("/idempotent", func(c *gin.Context) {
|
||||
executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
|
||||
executed.Add(1)
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
return gin.H{"ok": true}, nil
|
||||
})
|
||||
})
|
||||
|
||||
call := func() (int, http.Header) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Idempotency-Key", "same-user-key")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
return rec.Code, rec.Header()
|
||||
}
|
||||
|
||||
var status1, status2 int
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() { defer wg.Done(); status1, _ = call() }()
|
||||
go func() { defer wg.Done(); status2, _ = call() }()
|
||||
wg.Wait()
|
||||
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
|
||||
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
|
||||
status3, headers3 := call()
|
||||
require.Equal(t, http.StatusOK, status3)
|
||||
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
|
||||
require.Equal(t, int32(1), executed.Load())
|
||||
}
|
||||
19
backend/internal/handler/logging.go
Normal file
19
backend/internal/handler/logging.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger {
|
||||
base := logger.L()
|
||||
if c != nil && c.Request != nil {
|
||||
base = logger.FromContext(c.Request.Context())
|
||||
}
|
||||
|
||||
if component != "" {
|
||||
fields = append([]zap.Field{zap.String("component", component)}, fields...)
|
||||
}
|
||||
return base.With(fields...)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
677
backend/internal/handler/openai_gateway_handler_test.go
Normal file
677
backend/internal/handler/openai_gateway_handler_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号的消息",
|
||||
errType: "server_error",
|
||||
message: `upstream returned "invalid" response`,
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠的消息",
|
||||
errType: "server_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
{
|
||||
name: "包含双引号和反斜杠的消息",
|
||||
errType: "upstream_error",
|
||||
message: `error parsing "key\value": unexpected token`,
|
||||
},
|
||||
{
|
||||
name: "包含换行符的消息",
|
||||
errType: "server_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "普通消息",
|
||||
errType: "upstream_error",
|
||||
message: "Upstream service temporarily unavailable",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
|
||||
// 验证 SSE 格式:event: error\ndata: {JSON}\n\n
|
||||
assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头")
|
||||
assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾")
|
||||
|
||||
// 提取 data 部分
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "应有 event 行和 data 行")
|
||||
dataLine := lines[1]
|
||||
require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头")
|
||||
jsonStr := strings.TrimPrefix(dataLine, "data: ")
|
||||
|
||||
// 验证 JSON 合法性
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal([]byte(jsonStr), &parsed)
|
||||
require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr)
|
||||
|
||||
// 验证结构
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "应包含 error 对象")
|
||||
assert.Equal(t, tt.errType, errorObj["type"])
|
||||
assert.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false)
|
||||
|
||||
// 非流式应返回 JSON 响应
|
||||
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "test error", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc(t *testing.T) {
|
||||
payload := `{"model":"gpt-5","input":"hello"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
|
||||
req.ContentLength = int64(len(payload))
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, payload, string(body))
|
||||
}
|
||||
|
||||
func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
|
||||
req.Body = http.MaxBytesReader(rec, req.Body, 4)
|
||||
|
||||
_, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
|
||||
require.Error(t, err)
|
||||
var maxErr *http.MaxBytesError
|
||||
require.ErrorAs(t, err, &maxErr)
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.True(t, wrote)
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
wrote := h.ensureForwardErrorResponse(c, false)
|
||||
|
||||
require.False(t, wrote)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
|
||||
})
|
||||
|
||||
t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
|
||||
})
|
||||
|
||||
t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||
})
|
||||
|
||||
t.Run("response_already_written_should_downgrade", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.String(http.StatusForbidden, "already written")
|
||||
require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
panic("test panic")
|
||||
}()
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||
assert.Equal(t, "Upstream request failed", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
}()
|
||||
})
|
||||
|
||||
require.False(t, c.Writer.Written())
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
streamStarted := false
|
||||
require.NotPanics(t, func() {
|
||||
func() {
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
panic("test panic")
|
||||
}()
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
}
|
||||
|
||||
func TestOpenAIMissingResponsesDependencies(t *testing.T) {
|
||||
t.Run("nil_handler", func(t *testing.T) {
|
||||
var h *OpenAIGatewayHandler
|
||||
require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
|
||||
})
|
||||
|
||||
t.Run("all_dependencies_missing", func(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.Equal(t,
|
||||
[]string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
|
||||
h.missingResponsesDependencies(),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("all_dependencies_present", func(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{
|
||||
concurrencyService: &service.ConcurrencyService{},
|
||||
},
|
||||
}
|
||||
require.Empty(t, h.missingResponsesDependencies())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
|
||||
t.Run("missing_dependencies_returns_503", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.False(t, ok)
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
errorObj, exists := parsed["error"].(map[string]any)
|
||||
require.True(t, exists)
|
||||
assert.Equal(t, "api_error", errorObj["type"])
|
||||
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||
})
|
||||
|
||||
t.Run("already_written_response_not_overridden", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.String(http.StatusTeapot, "already written")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.False(t, ok)
|
||||
require.Equal(t, http.StatusTeapot, w.Code)
|
||||
assert.Equal(t, "already written", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: &ConcurrencyHelper{
|
||||
concurrencyService: &service.ConcurrencyService{},
|
||||
},
|
||||
}
|
||||
ok := h.ensureResponsesDependencies(c, nil)
|
||||
|
||||
require.True(t, ok)
|
||||
require.False(t, c.Writer.Written())
|
||||
assert.Equal(t, "", w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
groupID := int64(2)
|
||||
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 10,
|
||||
GroupID: &groupID,
|
||||
})
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
UserID: 1,
|
||||
Concurrency: 1,
|
||||
})
|
||||
|
||||
// 故意使用未初始化依赖,验证快速失败而不是崩溃。
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.Responses(c)
|
||||
})
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, w.Code)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "api_error", errorObj["type"])
|
||||
assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
|
||||
`{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
|
||||
))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
groupID := int64(2)
|
||||
c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
|
||||
ID: 101,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: 1},
|
||||
})
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
|
||||
UserID: 1,
|
||||
Concurrency: 1,
|
||||
})
|
||||
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||
h.Responses(c)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||
c.Request.Header.Set("Upgrade", "websocket")
|
||||
c.Request.Header.Set("Connection", "Upgrade")
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.ResponsesWebSocket(c)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
|
||||
|
||||
h := &OpenAIGatewayHandler{}
|
||||
h.ResponsesWebSocket(c)
|
||||
|
||||
require.Equal(t, http.StatusUpgradeRequired, w.Code)
|
||||
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
|
||||
))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.Error(t, err)
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cache := &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return false, errors.New("user slot unavailable")
|
||||
},
|
||||
}
|
||||
h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
|
||||
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
|
||||
defer wsServer.Close()
|
||||
|
||||
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
|
||||
cancelDial()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = clientConn.CloseNow()
|
||||
}()
|
||||
|
||||
writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
|
||||
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
|
||||
))
|
||||
cancelWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
_, _, err = clientConn.Read(readCtx)
|
||||
cancelRead()
|
||||
require.Error(t, err)
|
||||
var closeErr coderws.CloseError
|
||||
require.ErrorAs(t, err, &closeErr)
|
||||
require.Equal(t, coderws.StatusInternalError, closeErr.Code)
|
||||
require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportHTTP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
func TestSetOpenAIClientTransportWS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
setOpenAIClientTransportWS(c)
|
||||
require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantModel string
|
||||
wantStream bool
|
||||
}{
|
||||
{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
|
||||
{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
|
||||
{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
|
||||
{"model 缺失", `{"stream":true}`, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := []byte(tt.body)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
model := ""
|
||||
if modelResult.Type == gjson.String {
|
||||
model = modelResult.String()
|
||||
}
|
||||
stream := gjson.GetBytes(body, "stream").Bool()
|
||||
require.Equal(t, tt.wantModel, model)
|
||||
require.Equal(t, tt.wantStream, stream)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验
|
||||
func TestOpenAIHandler_GjsonValidation(t *testing.T) {
|
||||
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||
require.False(t, gjson.ValidBytes([]byte(`{invalid json`)))
|
||||
|
||||
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||
body := []byte(`{"model":123}`)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
require.True(t, modelResult.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult.Type)
|
||||
|
||||
// model 为 null → 类型不是 gjson.String,应被拒绝
|
||||
body2 := []byte(`{"model":null}`)
|
||||
modelResult2 := gjson.GetBytes(body2, "model")
|
||||
require.True(t, modelResult2.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult2.Type)
|
||||
|
||||
// stream 为 string → 类型既不是 True 也不是 False,应被拒绝
|
||||
body3 := []byte(`{"model":"gpt-4","stream":"true"}`)
|
||||
streamResult := gjson.GetBytes(body3, "stream")
|
||||
require.True(t, streamResult.Exists())
|
||||
require.NotEqual(t, gjson.True, streamResult.Type)
|
||||
require.NotEqual(t, gjson.False, streamResult.Type)
|
||||
|
||||
// stream 为 int → 同上
|
||||
body4 := []byte(`{"model":"gpt-4","stream":1}`)
|
||||
streamResult2 := gjson.GetBytes(body4, "stream")
|
||||
require.True(t, streamResult2.Exists())
|
||||
require.NotEqual(t, gjson.True, streamResult2.Type)
|
||||
require.NotEqual(t, gjson.False, streamResult2.Type)
|
||||
}
|
||||
|
||||
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
|
||||
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
||||
// 测试 1:无 instructions → 注入
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
existing := gjson.GetBytes(body, "instructions").String()
|
||||
require.Empty(t, existing)
|
||||
newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
|
||||
|
||||
// 测试 2:已有 instructions → 不覆盖
|
||||
body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
|
||||
existing2 := gjson.GetBytes(body2, "instructions").String()
|
||||
require.Equal(t, "existing", existing2)
|
||||
|
||||
// 测试 3:空白 instructions → 注入
|
||||
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
|
||||
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
|
||||
require.Empty(t, existing3)
|
||||
|
||||
// 测试 4:sjson.SetBytes 返回错误时不应 panic
|
||||
// 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理
|
||||
validBody := []byte(`{"model":"gpt-4"}`)
|
||||
result, setErr := sjson.SetBytes(validBody, "instructions", "hello")
|
||||
require.NoError(t, setErr)
|
||||
require.True(t, gjson.ValidBytes(result))
|
||||
}
|
||||
|
||||
func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
|
||||
t.Helper()
|
||||
if cache == nil {
|
||||
cache = &concurrencyCacheMock{
|
||||
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: &service.OpenAIGatewayService{},
|
||||
billingCacheService: &service.BillingCacheService{},
|
||||
apiKeyService: &service.APIKeyService{},
|
||||
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
|
||||
t.Helper()
|
||||
groupID := int64(2)
|
||||
apiKey := &service.APIKey{
|
||||
ID: 101,
|
||||
GroupID: &groupID,
|
||||
User: &service.User{ID: subject.UserID},
|
||||
}
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), subject)
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/openai/v1/responses", h.ResponsesWebSocket)
|
||||
return httptest.NewServer(router)
|
||||
}
|
||||
@@ -41,9 +41,8 @@ const (
|
||||
)
|
||||
|
||||
type opsErrorLogJob struct {
|
||||
ops *service.OpsService
|
||||
entry *service.OpsInsertErrorLogInput
|
||||
requestBody []byte
|
||||
ops *service.OpsService
|
||||
entry *service.OpsInsertErrorLogInput
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -58,6 +57,7 @@ var (
|
||||
opsErrorLogEnqueued atomic.Int64
|
||||
opsErrorLogDropped atomic.Int64
|
||||
opsErrorLogProcessed atomic.Int64
|
||||
opsErrorLogSanitized atomic.Int64
|
||||
|
||||
opsErrorLogLastDropLogAt atomic.Int64
|
||||
|
||||
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
|
||||
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||
cancel()
|
||||
opsErrorLogProcessed.Add(1)
|
||||
}()
|
||||
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
|
||||
}
|
||||
}
|
||||
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
|
||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||
if ops == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
|
||||
}
|
||||
|
||||
select {
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
|
||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
|
||||
opsErrorLogQueueLen.Add(1)
|
||||
opsErrorLogEnqueued.Add(1)
|
||||
default:
|
||||
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
|
||||
return opsErrorLogProcessed.Load()
|
||||
}
|
||||
|
||||
func OpsErrorLogSanitizedTotal() int64 {
|
||||
return opsErrorLogSanitized.Load()
|
||||
}
|
||||
|
||||
func maybeLogOpsErrorLogDrop() {
|
||||
now := time.Now().Unix()
|
||||
|
||||
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
|
||||
queueCap := OpsErrorLogQueueCapacity()
|
||||
|
||||
log.Printf(
|
||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
|
||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
|
||||
queued,
|
||||
queueCap,
|
||||
opsErrorLogEnqueued.Load(),
|
||||
opsErrorLogDropped.Load(),
|
||||
opsErrorLogProcessed.Load(),
|
||||
opsErrorLogSanitized.Load(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
c.Set(opsModelKey, model)
|
||||
c.Set(opsStreamKey, stream)
|
||||
if len(requestBody) > 0 {
|
||||
c.Set(opsRequestBodyKey, requestBody)
|
||||
}
|
||||
if c.Request != nil && model != "" {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
v, ok := c.Get(opsRequestBodyKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
raw, ok := v.([]byte)
|
||||
if !ok || len(raw) == 0 {
|
||||
return
|
||||
}
|
||||
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
|
||||
opsErrorLogSanitized.Add(1)
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
||||
if c == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
c.Set(opsAccountIDKey, accountID)
|
||||
if c.Request != nil {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID)
|
||||
if len(platform) > 0 {
|
||||
p := strings.TrimSpace(platform[0])
|
||||
if p != "" {
|
||||
ctx = context.WithValue(ctx, ctxkey.Platform, p)
|
||||
}
|
||||
}
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
type opsCaptureWriter struct {
|
||||
@@ -275,6 +311,35 @@ type opsCaptureWriter struct {
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
const opsCaptureWriterLimit = 64 * 1024
|
||||
|
||||
var opsCaptureWriterPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &opsCaptureWriter{limit: opsCaptureWriterLimit}
|
||||
},
|
||||
}
|
||||
|
||||
func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
|
||||
w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
|
||||
if !ok || w == nil {
|
||||
w = &opsCaptureWriter{}
|
||||
}
|
||||
w.ResponseWriter = rw
|
||||
w.limit = opsCaptureWriterLimit
|
||||
w.buf.Reset()
|
||||
return w
|
||||
}
|
||||
|
||||
func releaseOpsCaptureWriter(w *opsCaptureWriter) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.ResponseWriter = nil
|
||||
w.limit = opsCaptureWriterLimit
|
||||
w.buf.Reset()
|
||||
opsCaptureWriterPool.Put(w)
|
||||
}
|
||||
|
||||
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
|
||||
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
|
||||
remaining := w.limit - w.buf.Len()
|
||||
@@ -306,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
|
||||
// - Streaming errors after the response has started (SSE) may still need explicit logging.
|
||||
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
|
||||
originalWriter := c.Writer
|
||||
w := acquireOpsCaptureWriter(originalWriter)
|
||||
defer func() {
|
||||
// Restore the original writer before returning so outer middlewares
|
||||
// don't observe a pooled wrapper that has been released.
|
||||
if c.Writer == w {
|
||||
c.Writer = originalWriter
|
||||
}
|
||||
releaseOpsCaptureWriter(w)
|
||||
}()
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
|
||||
@@ -507,6 +581,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
RetryCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
applyOpsLatencyFieldsFromContext(c, entry)
|
||||
|
||||
if apiKey != nil {
|
||||
entry.APIKeyID = &apiKey.ID
|
||||
@@ -528,14 +603,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||
requestBody = b
|
||||
}
|
||||
}
|
||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||
@@ -544,7 +614,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -592,8 +662,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
requestID = c.Writer.Header().Get("x-request-id")
|
||||
}
|
||||
|
||||
phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
|
||||
isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
|
||||
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
|
||||
|
||||
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
|
||||
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
|
||||
|
||||
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
|
||||
errorSource := classifyOpsErrorSource(phase, parsed.Message)
|
||||
@@ -615,8 +687,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
|
||||
ErrorPhase: phase,
|
||||
ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
|
||||
Severity: classifyOpsSeverity(parsed.ErrorType, status),
|
||||
ErrorType: normalizedType,
|
||||
Severity: classifyOpsSeverity(normalizedType, status),
|
||||
StatusCode: status,
|
||||
IsBusinessLimited: isBusinessLimited,
|
||||
IsCountTokens: isCountTokensRequest(c),
|
||||
@@ -628,10 +700,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
ErrorSource: errorSource,
|
||||
ErrorOwner: errorOwner,
|
||||
|
||||
IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
|
||||
IsRetryable: classifyOpsIsRetryable(normalizedType, status),
|
||||
RetryCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
applyOpsLatencyFieldsFromContext(c, entry)
|
||||
|
||||
// Capture upstream error context set by gateway services (if present).
|
||||
// This does NOT affect the client response; it enriches Ops troubleshooting data.
|
||||
@@ -707,17 +780,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
||||
requestBody = b
|
||||
}
|
||||
}
|
||||
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
||||
// Do NOT store Authorization/Cookie/etc.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -760,6 +828,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
|
||||
entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
|
||||
entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
|
||||
entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
|
||||
entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
|
||||
}
|
||||
|
||||
func getContextLatencyMs(c *gin.Context, key string) *int64 {
|
||||
if c == nil || strings.TrimSpace(key) == "" {
|
||||
return nil
|
||||
}
|
||||
v, ok := c.Get(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var ms int64
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
ms = int64(t)
|
||||
case int32:
|
||||
ms = int64(t)
|
||||
case int64:
|
||||
ms = t
|
||||
case float64:
|
||||
ms = int64(t)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
if ms < 0 {
|
||||
return nil
|
||||
}
|
||||
return &ms
|
||||
}
|
||||
|
||||
type parsedOpsError struct {
|
||||
ErrorType string
|
||||
Message string
|
||||
@@ -835,8 +941,29 @@ func guessPlatformFromPath(path string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// isKnownOpsErrorType returns true if t is a recognized error type used by the
|
||||
// ops classification pipeline. Upstream proxies sometimes return garbage values
|
||||
// (e.g. the Go-serialized literal "<nil>") which would pollute phase/severity
|
||||
// classification if accepted blindly.
|
||||
func isKnownOpsErrorType(t string) bool {
|
||||
switch t {
|
||||
case "invalid_request_error",
|
||||
"authentication_error",
|
||||
"rate_limit_error",
|
||||
"billing_error",
|
||||
"subscription_error",
|
||||
"upstream_error",
|
||||
"overloaded_error",
|
||||
"api_error",
|
||||
"not_found_error",
|
||||
"forbidden_error":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeOpsErrorType(errType string, code string) string {
|
||||
if errType != "" {
|
||||
if errType != "" && isKnownOpsErrorType(errType) {
|
||||
return errType
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
|
||||
276
backend/internal/handler/ops_error_logger_test.go
Normal file
276
backend/internal/handler/ops_error_logger_test.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func resetOpsErrorLoggerStateForTest(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
opsErrorLogMu.Lock()
|
||||
ch := opsErrorLogQueue
|
||||
opsErrorLogQueue = nil
|
||||
opsErrorLogStopping = true
|
||||
opsErrorLogMu.Unlock()
|
||||
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
opsErrorLogWorkersWg.Wait()
|
||||
|
||||
opsErrorLogOnce = sync.Once{}
|
||||
opsErrorLogStopOnce = sync.Once{}
|
||||
opsErrorLogWorkersWg = sync.WaitGroup{}
|
||||
opsErrorLogMu = sync.RWMutex{}
|
||||
opsErrorLogStopping = false
|
||||
|
||||
opsErrorLogQueueLen.Store(0)
|
||||
opsErrorLogEnqueued.Store(0)
|
||||
opsErrorLogDropped.Store(0)
|
||||
opsErrorLogProcessed.Store(0)
|
||||
opsErrorLogSanitized.Store(0)
|
||||
opsErrorLogLastDropLogAt.Store(0)
|
||||
|
||||
opsErrorLogShutdownCh = make(chan struct{})
|
||||
opsErrorLogShutdownOnce = sync.Once{}
|
||||
opsErrorLogDrained.Store(false)
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.NotNil(t, entry.RequestBodyJSON)
|
||||
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
|
||||
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte("not-json")
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
// 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。
|
||||
opsErrorLogOnce.Do(func() {})
|
||||
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogQueue = make(chan opsErrorLogJob, 1)
|
||||
opsErrorLogMu.Unlock()
|
||||
|
||||
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
|
||||
require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal())
|
||||
require.Equal(t, int64(1), OpsErrorLogDroppedTotal())
|
||||
require.Equal(t, int64(1), OpsErrorLogQueueLength())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(nil, entry)
|
||||
attachOpsRequestBodyToEntry(&gin.Context{}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 无请求体 key
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
|
||||
// 错误类型
|
||||
c.Set(opsRequestBodyKey, "not-bytes")
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
// 空 bytes
|
||||
c.Set(opsRequestBodyKey, []byte{})
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||
|
||||
// nil 入参分支
|
||||
enqueueOpsErrorLog(nil, entry)
|
||||
enqueueOpsErrorLog(ops, nil)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// shutdown 分支
|
||||
close(opsErrorLogShutdownCh)
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// stopping 分支
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogStopping = true
|
||||
opsErrorLogMu.Unlock()
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
|
||||
// queue nil 分支(防止启动 worker 干扰)
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
opsErrorLogOnce.Do(func() {})
|
||||
opsErrorLogMu.Lock()
|
||||
opsErrorLogQueue = nil
|
||||
opsErrorLogMu.Unlock()
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
|
||||
}
|
||||
|
||||
func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
|
||||
writer := acquireOpsCaptureWriter(c.Writer)
|
||||
require.NotNil(t, writer)
|
||||
_, err := writer.buf.WriteString("temp-error-body")
|
||||
require.NoError(t, err)
|
||||
|
||||
releaseOpsCaptureWriter(writer)
|
||||
|
||||
reused := acquireOpsCaptureWriter(c.Writer)
|
||||
defer releaseOpsCaptureWriter(reused)
|
||||
|
||||
require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
|
||||
}
|
||||
|
||||
func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(middleware2.Recovery())
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
|
||||
c.Status(http.StatusNoContent)
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
r.ServeHTTP(rec, req)
|
||||
})
|
||||
require.Equal(t, http.StatusNoContent, rec.Code)
|
||||
}
|
||||
|
||||
func TestIsKnownOpsErrorType(t *testing.T) {
|
||||
known := []string{
|
||||
"invalid_request_error",
|
||||
"authentication_error",
|
||||
"rate_limit_error",
|
||||
"billing_error",
|
||||
"subscription_error",
|
||||
"upstream_error",
|
||||
"overloaded_error",
|
||||
"api_error",
|
||||
"not_found_error",
|
||||
"forbidden_error",
|
||||
}
|
||||
for _, k := range known {
|
||||
require.True(t, isKnownOpsErrorType(k), "expected known: %s", k)
|
||||
}
|
||||
|
||||
unknown := []string{"<nil>", "null", "", "random_error", "some_new_type", "<nil>\u003e"}
|
||||
for _, u := range unknown {
|
||||
require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpsErrorType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
code string
|
||||
want string
|
||||
}{
|
||||
// Known types pass through.
|
||||
{"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"},
|
||||
{"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"},
|
||||
{"known upstream_error", "upstream_error", "", "upstream_error"},
|
||||
|
||||
// Unknown/garbage types are rejected and fall through to code-based or default.
|
||||
{"nil literal from upstream", "<nil>", "", "api_error"},
|
||||
{"null string", "null", "", "api_error"},
|
||||
{"random string", "something_weird", "", "api_error"},
|
||||
|
||||
// Unknown type but known code still maps correctly.
|
||||
{"nil with INSUFFICIENT_BALANCE code", "<nil>", "INSUFFICIENT_BALANCE", "billing_error"},
|
||||
{"nil with USAGE_LIMIT_EXCEEDED code", "<nil>", "USAGE_LIMIT_EXCEEDED", "subscription_error"},
|
||||
|
||||
// Empty type falls through to code-based mapping.
|
||||
{"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"},
|
||||
{"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"},
|
||||
{"empty type no code", "", "", "api_error"},
|
||||
|
||||
// Known type overrides conflicting code-based mapping.
|
||||
{"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeOpsErrorType(tt.errType, tt.code)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -32,25 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
InvitationCodeEnabled: settings.InvitationCodeEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
|
||||
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
SoraClientEnabled: settings.SoraClientEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
979
backend/internal/handler/sora_client_handler.go
Normal file
979
backend/internal/handler/sora_client_handler.go
Normal file
@@ -0,0 +1,979 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// 上游模型缓存 TTL
|
||||
modelCacheTTL = 1 * time.Hour // 上游获取成功
|
||||
modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
|
||||
)
|
||||
|
||||
// SoraClientHandler 处理 Sora 客户端 API 请求。
|
||||
type SoraClientHandler struct {
|
||||
genService *service.SoraGenerationService
|
||||
quotaService *service.SoraQuotaService
|
||||
s3Storage *service.SoraS3Storage
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
gatewayService *service.GatewayService
|
||||
mediaStorage *service.SoraMediaStorage
|
||||
apiKeyService *service.APIKeyService
|
||||
|
||||
// 上游模型缓存
|
||||
modelCacheMu sync.RWMutex
|
||||
cachedFamilies []service.SoraModelFamily
|
||||
modelCacheTime time.Time
|
||||
modelCacheUpstream bool // 是否来自上游(决定 TTL)
|
||||
}
|
||||
|
||||
// NewSoraClientHandler 创建 Sora 客户端 Handler。
|
||||
func NewSoraClientHandler(
|
||||
genService *service.SoraGenerationService,
|
||||
quotaService *service.SoraQuotaService,
|
||||
s3Storage *service.SoraS3Storage,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
gatewayService *service.GatewayService,
|
||||
mediaStorage *service.SoraMediaStorage,
|
||||
apiKeyService *service.APIKeyService,
|
||||
) *SoraClientHandler {
|
||||
return &SoraClientHandler{
|
||||
genService: genService,
|
||||
quotaService: quotaService,
|
||||
s3Storage: s3Storage,
|
||||
soraGatewayService: soraGatewayService,
|
||||
gatewayService: gatewayService,
|
||||
mediaStorage: mediaStorage,
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateRequest 生成请求。
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
MediaType string `json:"media_type"` // video / image,默认 video
|
||||
VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
|
||||
ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
|
||||
}
|
||||
|
||||
// Generate 异步生成 — 创建 pending 记录后立即返回。
|
||||
// POST /api/v1/sora/generate
|
||||
func (h *SoraClientHandler) Generate(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req GenerateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.MediaType == "" {
|
||||
req.MediaType = "video"
|
||||
}
|
||||
req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
|
||||
|
||||
// 并发数检查(最多 3 个)
|
||||
activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if activeCount >= 3 {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
|
||||
// 配额检查(粗略检查,实际文件大小在上传后才知道)
|
||||
if h.quotaService != nil {
|
||||
if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 API Key ID 和 Group ID
|
||||
var apiKeyID *int64
|
||||
var groupID *int64
|
||||
|
||||
if req.APIKeyID != nil && h.apiKeyService != nil {
|
||||
// 前端传递了 api_key_id,需要校验
|
||||
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "API Key 不存在")
|
||||
return
|
||||
}
|
||||
if apiKey.UserID != userID {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
|
||||
return
|
||||
}
|
||||
if apiKey.Status != service.StatusAPIKeyActive {
|
||||
response.Error(c, http.StatusForbidden, "API Key 不可用")
|
||||
return
|
||||
}
|
||||
apiKeyID = &apiKey.ID
|
||||
groupID = apiKey.GroupID
|
||||
} else if id, ok := c.Get("api_key_id"); ok {
|
||||
// 兼容 API Key 认证路径(/sora/v1/ 网关路由)
|
||||
if v, ok := id.(int64); ok {
|
||||
apiKeyID = &v
|
||||
}
|
||||
}
|
||||
|
||||
gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
|
||||
response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 启动后台异步生成 goroutine
|
||||
go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"generation_id": gen.ID,
|
||||
"status": gen.Status,
|
||||
})
|
||||
}
|
||||
|
||||
// processGeneration 后台异步执行 Sora 生成任务。
|
||||
// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
|
||||
func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 标记为生成中
|
||||
if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
mediaType,
|
||||
videoCount,
|
||||
strings.TrimSpace(imageInput) != "",
|
||||
len(strings.TrimSpace(prompt)),
|
||||
)
|
||||
|
||||
// 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
|
||||
if groupID == nil {
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
}
|
||||
|
||||
if h.gatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 选择 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
err,
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
|
||||
genID,
|
||||
userID,
|
||||
groupIDForLog(groupID),
|
||||
model,
|
||||
account.ID,
|
||||
account.Name,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
)
|
||||
|
||||
// 构建 chat completions 请求体(非流式)
|
||||
body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
|
||||
|
||||
if h.soraGatewayService == nil {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
|
||||
recorder := httptest.NewRecorder()
|
||||
mockGinCtx, _ := gin.CreateTestContext(recorder)
|
||||
mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
|
||||
|
||||
// 调用 Forward(非流式)
|
||||
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
err,
|
||||
)
|
||||
// 检查是否已取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
|
||||
mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
|
||||
if mediaURL == "" {
|
||||
logger.LegacyPrintf(
|
||||
"handler.sora_client",
|
||||
"[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
|
||||
genID,
|
||||
account.ID,
|
||||
model,
|
||||
recorder.Code,
|
||||
trimForLog(recorder.Body.String(), 400),
|
||||
)
|
||||
_ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查任务是否已被取消
|
||||
gen, _ := h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
|
||||
return
|
||||
}
|
||||
|
||||
// 三层降级存储:S3 → 本地 → 上游临时 URL
|
||||
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
|
||||
|
||||
usageAdded := false
|
||||
if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
_ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
// 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
|
||||
gen, _ = h.genService.GetByID(ctx, genID, userID)
|
||||
if gen != nil && gen.Status == service.SoraGenStatusCancelled {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 标记完成
|
||||
if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationStateConflict) {
|
||||
h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
|
||||
}
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
|
||||
}
|
||||
|
||||
// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
|
||||
func (h *SoraClientHandler) storeMediaWithDegradation(
|
||||
ctx context.Context, userID int64, mediaType string,
|
||||
mediaURL string, mediaURLs []string,
|
||||
) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
|
||||
urls := mediaURLs
|
||||
if len(urls) == 0 {
|
||||
urls = []string{mediaURL}
|
||||
}
|
||||
|
||||
// 第一层:尝试 S3
|
||||
if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
|
||||
keys := make([]string, 0, len(urls))
|
||||
var totalSize int64
|
||||
allOK := true
|
||||
for _, u := range urls {
|
||||
key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
|
||||
allOK = false
|
||||
// 清理已上传的文件
|
||||
if len(keys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
}
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
totalSize += size
|
||||
}
|
||||
if allOK && len(keys) > 0 {
|
||||
accessURLs := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
|
||||
_ = h.s3Storage.DeleteObjects(ctx, keys)
|
||||
allOK = false
|
||||
break
|
||||
}
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
}
|
||||
if allOK && len(accessURLs) > 0 {
|
||||
return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 第二层:尝试本地存储
|
||||
if h.mediaStorage != nil && h.mediaStorage.Enabled() {
|
||||
storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
|
||||
if err == nil && len(storedPaths) > 0 {
|
||||
firstPath := storedPaths[0]
|
||||
totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
|
||||
if sizeErr != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
|
||||
}
|
||||
return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
|
||||
}
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
|
||||
}
|
||||
|
||||
// 第三层:保留上游临时 URL
|
||||
return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
|
||||
}
|
||||
|
||||
// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
|
||||
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
|
||||
body := map[string]any{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": prompt},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
if imageInput != "" {
|
||||
body["image_input"] = imageInput
|
||||
}
|
||||
if videoCount > 1 {
|
||||
body["video_count"] = videoCount
|
||||
}
|
||||
b, _ := json.Marshal(body)
|
||||
return b
|
||||
}
|
||||
|
||||
func normalizeVideoCount(mediaType string, videoCount int) int {
|
||||
if mediaType != "video" {
|
||||
return 1
|
||||
}
|
||||
if videoCount <= 0 {
|
||||
return 1
|
||||
}
|
||||
if videoCount > 3 {
|
||||
return 3
|
||||
}
|
||||
return videoCount
|
||||
}
|
||||
|
||||
// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
|
||||
// OAuth 路径:ForwardResult.MediaURL 已填充。
|
||||
// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
|
||||
func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
|
||||
// 优先从 ForwardResult 获取(OAuth 路径)
|
||||
if result != nil && result.MediaURL != "" {
|
||||
// 尝试从响应体获取完整 URL 列表
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
return result.MediaURL, []string{result.MediaURL}
|
||||
}
|
||||
|
||||
// 从响应体解析(APIKey 路径)
|
||||
if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
|
||||
return urls[0], urls
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
|
||||
func parseMediaURLsFromBody(body []byte) []string {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先 media_urls(多图数组)
|
||||
if rawURLs, ok := resp["media_urls"]; ok {
|
||||
if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
|
||||
urls := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
urls = append(urls, s)
|
||||
}
|
||||
}
|
||||
if len(urls) > 0 {
|
||||
return urls
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 media_url(单个 URL)
|
||||
if url, ok := resp["media_url"].(string); ok && url != "" {
|
||||
return []string{url}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGenerations 查询生成记录列表。
|
||||
// GET /api/v1/sora/generations
|
||||
func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
params := service.SoraGenerationListParams{
|
||||
UserID: userID,
|
||||
Status: c.Query("status"),
|
||||
StorageType: c.Query("storage_type"),
|
||||
MediaType: c.Query("media_type"),
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
gens, total, err := h.genService.List(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 为 S3 记录动态生成预签名 URL
|
||||
for _, gen := range gens {
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"data": gens,
|
||||
"total": total,
|
||||
"page": page,
|
||||
})
|
||||
}
|
||||
|
||||
// GetGeneration 查询生成记录详情。
|
||||
// GET /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
|
||||
response.Success(c, gen)
|
||||
}
|
||||
|
||||
// DeleteGeneration 删除生成记录。
|
||||
// DELETE /api/v1/sora/generations/:id
|
||||
func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
|
||||
if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
|
||||
paths := gen.MediaURLs
|
||||
if len(paths) == 0 && gen.MediaURL != "" {
|
||||
paths = []string{gen.MediaURL}
|
||||
}
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
// GetQuota 查询用户存储配额。
|
||||
// GET /api/v1/sora/quota
|
||||
func (h *SoraClientHandler) GetQuota(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
if h.quotaService == nil {
|
||||
response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
|
||||
return
|
||||
}
|
||||
|
||||
quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, quota)
|
||||
}
|
||||
|
||||
// CancelGeneration 取消生成任务。
|
||||
// POST /api/v1/sora/generations/:id/cancel
|
||||
func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
// 权限校验
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
_ = gen
|
||||
|
||||
if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
|
||||
if errors.Is(err, service.ErrSoraGenerationNotActive) {
|
||||
response.Error(c, http.StatusConflict, "任务已结束,无法取消")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "已取消"})
|
||||
}
|
||||
|
||||
// SaveToStorage 手动保存 upstream 记录到 S3。
|
||||
// POST /api/v1/sora/generations/:id/save
|
||||
func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusBadRequest, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.Error(c, http.StatusNotFound, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if gen.StorageType != service.SoraStorageTypeUpstream {
|
||||
response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
|
||||
return
|
||||
}
|
||||
if gen.MediaURL == "" {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
|
||||
response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
|
||||
return
|
||||
}
|
||||
|
||||
sourceURLs := gen.MediaURLs
|
||||
if len(sourceURLs) == 0 && gen.MediaURL != "" {
|
||||
sourceURLs = []string{gen.MediaURL}
|
||||
}
|
||||
if len(sourceURLs) == 0 {
|
||||
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
|
||||
return
|
||||
}
|
||||
|
||||
uploadedKeys := make([]string, 0, len(sourceURLs))
|
||||
accessURLs := make([]string, 0, len(sourceURLs))
|
||||
var totalSize int64
|
||||
|
||||
for _, sourceURL := range sourceURLs {
|
||||
objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
|
||||
if uploadErr != nil {
|
||||
if len(uploadedKeys) > 0 {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
}
|
||||
var upstreamErr *service.UpstreamDownloadError
|
||||
if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
|
||||
response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
|
||||
return
|
||||
}
|
||||
accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
|
||||
if err != nil {
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
uploadedKeys = append(uploadedKeys, objectKey)
|
||||
accessURLs = append(accessURLs, accessURL)
|
||||
totalSize += fileSize
|
||||
}
|
||||
|
||||
usageAdded := false
|
||||
if totalSize > 0 && h.quotaService != nil {
|
||||
if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
var quotaErr *service.QuotaExceededError
|
||||
if errors.As(err, "aErr) {
|
||||
response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
usageAdded = true
|
||||
}
|
||||
|
||||
if err := h.genService.UpdateStorageForCompleted(
|
||||
c.Request.Context(),
|
||||
id,
|
||||
accessURLs[0],
|
||||
accessURLs,
|
||||
service.SoraStorageTypeS3,
|
||||
uploadedKeys,
|
||||
totalSize,
|
||||
); err != nil {
|
||||
_ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
|
||||
if usageAdded && h.quotaService != nil {
|
||||
_ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "已保存到 S3",
|
||||
"object_key": uploadedKeys[0],
|
||||
"object_keys": uploadedKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStorageStatus 返回存储状态。
|
||||
// GET /api/v1/sora/storage-status
|
||||
func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
|
||||
s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
|
||||
s3Healthy := false
|
||||
if s3Enabled {
|
||||
s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
|
||||
}
|
||||
localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
|
||||
response.Success(c, gin.H{
|
||||
"s3_enabled": s3Enabled,
|
||||
"s3_healthy": s3Healthy,
|
||||
"local_enabled": localEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
|
||||
switch storageType {
|
||||
case service.SoraStorageTypeS3:
|
||||
if h.s3Storage != nil && len(s3Keys) > 0 {
|
||||
if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
|
||||
}
|
||||
}
|
||||
case service.SoraStorageTypeLocal:
|
||||
if h.mediaStorage != nil && len(localPaths) > 0 {
|
||||
if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从 gin 上下文中提取用户 ID。
|
||||
func getUserIDFromContext(c *gin.Context) int64 {
|
||||
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
|
||||
return subject.UserID
|
||||
}
|
||||
|
||||
if id, ok := c.Get("user_id"); ok {
|
||||
switch v := id.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
n, _ := strconv.ParseInt(v, 10, 64)
|
||||
return n
|
||||
}
|
||||
}
|
||||
// 尝试从 JWT claims 获取
|
||||
if id, ok := c.Get("userID"); ok {
|
||||
if v, ok := id.(int64); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func groupIDForLog(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
func trimForLog(raw string, maxLen int) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if maxLen <= 0 || len(trimmed) <= maxLen {
|
||||
return trimmed
|
||||
}
|
||||
return trimmed[:maxLen] + "...(truncated)"
|
||||
}
|
||||
|
||||
// GetModels 获取可用 Sora 模型家族列表。
|
||||
// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
|
||||
// GET /api/v1/sora/models
|
||||
func (h *SoraClientHandler) GetModels(c *gin.Context) {
|
||||
families := h.getModelFamilies(c.Request.Context())
|
||||
response.Success(c, families)
|
||||
}
|
||||
|
||||
// getModelFamilies 获取模型家族列表(带缓存)。
|
||||
func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
|
||||
// 读锁检查缓存
|
||||
h.modelCacheMu.RLock()
|
||||
ttl := modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
families := h.cachedFamilies
|
||||
h.modelCacheMu.RUnlock()
|
||||
return families
|
||||
}
|
||||
h.modelCacheMu.RUnlock()
|
||||
|
||||
// 写锁更新缓存
|
||||
h.modelCacheMu.Lock()
|
||||
defer h.modelCacheMu.Unlock()
|
||||
|
||||
// double-check
|
||||
ttl = modelCacheTTL
|
||||
if !h.modelCacheUpstream {
|
||||
ttl = modelCacheFailedTTL
|
||||
}
|
||||
if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
|
||||
return h.cachedFamilies
|
||||
}
|
||||
|
||||
// 尝试从上游获取
|
||||
families, err := h.fetchUpstreamModels(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
|
||||
families = service.BuildSoraModelFamilies()
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = false
|
||||
return families
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
|
||||
h.cachedFamilies = families
|
||||
h.modelCacheTime = time.Now()
|
||||
h.modelCacheUpstream = true
|
||||
return families
|
||||
}
|
||||
|
||||
// fetchUpstreamModels 从上游 Sora API 获取模型列表。
|
||||
func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
|
||||
if h.gatewayService == nil {
|
||||
return nil, fmt.Errorf("gatewayService 未初始化")
|
||||
}
|
||||
|
||||
// 设置 ForcePlatform 用于 Sora 账号选择
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
|
||||
|
||||
// 选择一个 Sora 账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
|
||||
}
|
||||
|
||||
// 仅支持 API Key 类型账号
|
||||
if account.Type != service.AccountTypeAPIKey {
|
||||
return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
|
||||
}
|
||||
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("账号缺少 api_key")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("账号缺少 base_url")
|
||||
}
|
||||
|
||||
// 构建上游模型列表请求
|
||||
modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
|
||||
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求上游失败: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析 OpenAI 格式的模型列表
|
||||
var modelsResp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(modelsResp.Data) == 0 {
|
||||
return nil, fmt.Errorf("上游返回空模型列表")
|
||||
}
|
||||
|
||||
// 提取模型 ID
|
||||
modelIDs := make([]string, 0, len(modelsResp.Data))
|
||||
for _, m := range modelsResp.Data {
|
||||
modelIDs = append(modelIDs, m.ID)
|
||||
}
|
||||
|
||||
// 转换为模型家族
|
||||
families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
|
||||
if len(families) == 0 {
|
||||
return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
|
||||
}
|
||||
|
||||
return families, nil
|
||||
}
|
||||
3153
backend/internal/handler/sora_client_handler_test.go
Normal file
3153
backend/internal/handler/sora_client_handler_test.go
Normal file
File diff suppressed because it is too large
Load Diff
685
backend/internal/handler/sora_gateway_handler.go
Normal file
685
backend/internal/handler/sora_gateway_handler.go
Normal file
@@ -0,0 +1,685 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SoraGatewayHandler handles Sora chat completions requests
|
||||
type SoraGatewayHandler struct {
|
||||
gatewayService *service.GatewayService
|
||||
soraGatewayService *service.SoraGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
}
|
||||
|
||||
// NewSoraGatewayHandler creates a new SoraGatewayHandler
|
||||
func NewSoraGatewayHandler(
|
||||
gatewayService *service.GatewayService,
|
||||
soraGatewayService *service.SoraGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||||
cfg *config.Config,
|
||||
) *SoraGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
streamMode := "force"
|
||||
soraTLSEnabled := true
|
||||
signKey := ""
|
||||
mediaRoot := "/app/data/sora"
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
||||
streamMode = mode
|
||||
}
|
||||
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||
mediaRoot = root
|
||||
}
|
||||
}
|
||||
return &SoraGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
soraGatewayService: soraGatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletions handles Sora /v1/chat/completions endpoint
|
||||
func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.sora_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
// 校验请求体 JSON 合法性
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
|
||||
msgsResult := gjson.GetBytes(body, "messages")
|
||||
if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
||||
return
|
||||
}
|
||||
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
|
||||
if !clientStream {
|
||||
if h.streamMode == "error" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
||||
return
|
||||
}
|
||||
var err error
|
||||
body, err = sjson.SetBytes(body, "stream", true)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
||||
|
||||
platform := ""
|
||||
if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||
platform = forced
|
||||
} else if apiKey.Group != nil {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
if platform != service.PlatformSora {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
|
||||
return
|
||||
}
|
||||
|
||||
streamStarted := false
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := generateOpenAISessionHash(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverBody []byte
|
||||
var lastFailoverHeaders http.Header
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int("last_upstream_status", lastFailoverStatus),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
proxyBound := account.ProxyID != nil
|
||||
proxyID := int64(0)
|
||||
if account.ProxyID != nil {
|
||||
proxyID = *account.ProxyID
|
||||
}
|
||||
tlsFingerprintEnabled := h.soraTLSEnabled
|
||||
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
clientStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_slot_acquire_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
switchCount++
|
||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.String("upstream_error_code", upstreamErrCode),
|
||||
zap.String("upstream_error_message", upstreamErrMsg),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
||||
continue
|
||||
}
|
||||
reqLog.Error("sora.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("sora.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && len(body) > 0 {
|
||||
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
}
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
if h.usageRecordWorkerPool != nil {
|
||||
h.usageRecordWorkerPool.Submit(task)
|
||||
return
|
||||
}
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.sora_gateway.chat_completions"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("sora.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
|
||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||
switch statusCode {
|
||||
case 401, 403, 404, 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||
}
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 404:
|
||||
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||
}
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHTTPHeaders(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
||||
if headers != nil {
|
||||
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
||||
contentType = strings.TrimSpace(headers.Get("content-type"))
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
||||
return rayID, mitigated, contentType
|
||||
}
|
||||
|
||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||
lower := strings.ToLower(message)
|
||||
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, status, errType, message)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// MediaProxy serves local Sora media files.
|
||||
func (h *SoraGatewayHandler) MediaProxy(c *gin.Context) {
|
||||
h.proxySoraMedia(c, false)
|
||||
}
|
||||
|
||||
// MediaProxySigned serves local Sora media files with signature verification.
|
||||
func (h *SoraGatewayHandler) MediaProxySigned(c *gin.Context) {
|
||||
h.proxySoraMedia(c, true)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) proxySoraMedia(c *gin.Context, requireSignature bool) {
|
||||
rawPath := c.Param("filepath")
|
||||
if rawPath == "" {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
cleaned := path.Clean(rawPath)
|
||||
if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
query := c.Request.URL.Query()
|
||||
if requireSignature {
|
||||
if h.soraMediaSigningKey == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Sora 媒体签名未配置",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
expiresStr := strings.TrimSpace(query.Get("expires"))
|
||||
signature := strings.TrimSpace(query.Get("sig"))
|
||||
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err != nil || expires <= time.Now().Unix() {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "authentication_error",
|
||||
"message": "Sora 媒体签名已过期",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
query.Del("sig")
|
||||
query.Del("expires")
|
||||
signingQuery := query.Encode()
|
||||
if !service.VerifySoraMediaURL(cleaned, signingQuery, expires, signature, h.soraMediaSigningKey) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "authentication_error",
|
||||
"message": "Sora 媒体签名无效",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(h.soraMediaRoot) == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Sora 媒体目录未配置",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
relative := strings.TrimPrefix(cleaned, "/")
|
||||
localPath := filepath.Join(h.soraMediaRoot, filepath.FromSlash(relative))
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.Status(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
c.Status(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
c.File(localPath)
|
||||
}
|
||||
698
backend/internal/handler/sora_gateway_handler_test.go
Normal file
698
backend/internal/handler/sora_gateway_handler_test.go
Normal file
@@ -0,0 +1,698 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/testutil"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// 编译期接口断言
|
||||
var _ service.SoraClient = (*stubSoraClient)(nil)
|
||||
var _ service.AccountRepository = (*stubAccountRepo)(nil)
|
||||
var _ service.GroupRepository = (*stubGroupRepo)(nil)
|
||||
var _ service.UsageLogRepository = (*stubUsageLogRepo)(nil)
|
||||
|
||||
type stubSoraClient struct {
|
||||
imageURLs []string
|
||||
}
|
||||
|
||||
func (s *stubSoraClient) Enabled() bool { return true }
|
||||
func (s *stubSoraClient) UploadImage(ctx context.Context, account *service.Account, data []byte, filename string) (string, error) {
|
||||
return "upload", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.Account, req service.SoraImageRequest) (string, error) {
|
||||
return "task-image", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||
return &service.SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
return "enhanced prompt", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
func (s *stubSoraClient) GetVideoTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraVideoTaskStatus, error) {
|
||||
return &service.SoraVideoTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
|
||||
type stubAccountRepo struct {
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { return nil }
|
||||
func (r *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if acc, ok := r.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, service.ErrAccountNotFound
|
||||
}
|
||||
func (r *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||
var result []*service.Account
|
||||
for _, id := range ids {
|
||||
if acc, ok := r.accounts[id]; ok {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
_, ok := r.accounts[id]
|
||||
return ok, nil
|
||||
}
|
||||
func (r *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
return map[string]int64{}, nil
|
||||
}
|
||||
func (r *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return nil }
|
||||
func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { return nil, nil }
|
||||
func (r *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { return nil }
|
||||
func (r *stubAccountRepo) ClearError(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
|
||||
return r.listSchedulable(), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
return r.listSchedulable(), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
|
||||
return r.listSchedulableByPlatform(platform), nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
for _, platform := range platforms {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
func (r *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
return r.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (r *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulable() []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *stubAccountRepo) listSchedulableByPlatform(platform string) []service.Account {
|
||||
var result []service.Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, *acc)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type stubGroupRepo struct {
|
||||
group *service.Group
|
||||
}
|
||||
|
||||
func (r *stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return nil }
|
||||
func (r *stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return r.group, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return r.group, nil
|
||||
}
|
||||
func (r *stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return nil }
|
||||
func (r *stubGroupRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (r *stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { return nil, nil }
|
||||
func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (r *stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (r *stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct{}
|
||||
|
||||
func (s *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (s *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
RunMode: config.RunModeSimple,
|
||||
Gateway: config.GatewayConfig{
|
||||
SoraStreamMode: "force",
|
||||
MaxAccountSwitches: 1,
|
||||
Scheduling: config.GatewaySchedulingConfig{
|
||||
LoadBatchEnabled: false,
|
||||
},
|
||||
},
|
||||
Concurrency: config.ConcurrencyConfig{PingInterval: 0},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
BaseURL: "https://sora.test",
|
||||
PollIntervalSeconds: 1,
|
||||
MaxPollAttempts: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &service.Account{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}
|
||||
accountRepo := &stubAccountRepo{accounts: map[int64]*service.Account{account.ID: account}}
|
||||
group := &service.Group{ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Hydrated: true}
|
||||
groupRepo := &stubGroupRepo{group: group}
|
||||
|
||||
usageLogRepo := &stubUsageLogRepo{}
|
||||
deferredService := service.NewDeferredService(accountRepo, nil, 0)
|
||||
billingService := service.NewBillingService(cfg, nil)
|
||||
concurrencyService := service.NewConcurrencyService(testutil.StubConcurrencyCache{})
|
||||
billingCacheService := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
|
||||
t.Cleanup(func() {
|
||||
billingCacheService.Stop()
|
||||
})
|
||||
|
||||
gatewayService := service.NewGatewayService(
|
||||
accountRepo,
|
||||
groupRepo,
|
||||
usageLogRepo,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
testutil.StubGatewayCache{},
|
||||
cfg,
|
||||
nil,
|
||||
concurrencyService,
|
||||
billingService,
|
||||
nil,
|
||||
billingCacheService,
|
||||
nil,
|
||||
nil,
|
||||
deferredService,
|
||||
nil,
|
||||
testutil.StubSessionLimitCache{},
|
||||
nil, // rpmCache
|
||||
nil, // digestStore
|
||||
)
|
||||
|
||||
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
|
||||
soraGatewayService := service.NewSoraGatewayService(soraClient, nil, nil, cfg)
|
||||
|
||||
handler := NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, nil, cfg)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
body := `{"model":"gpt-image","messages":[{"role":"user","content":"hello"}]}`
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/sora/v1/chat/completions", strings.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
apiKey := &service.APIKey{
|
||||
ID: 1,
|
||||
UserID: 1,
|
||||
Status: service.StatusActive,
|
||||
GroupID: &group.ID,
|
||||
User: &service.User{ID: 1, Concurrency: 1, Status: service.StatusActive},
|
||||
Group: group,
|
||||
}
|
||||
c.Set(string(middleware.ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: apiKey.User.Concurrency})
|
||||
|
||||
handler.ChatCompletions(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.NotEmpty(t, resp["media_url"])
|
||||
}
|
||||
|
||||
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
|
||||
func TestSoraHandler_StreamForcing(t *testing.T) {
|
||||
// 测试 1:stream=false 时 sjson 强制修改为 true
|
||||
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
|
||||
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||
require.False(t, clientStream)
|
||||
newBody, err := sjson.SetBytes(body, "stream", true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
|
||||
|
||||
// 测试 2:stream=true 时不修改
|
||||
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
|
||||
require.True(t, gjson.GetBytes(body2, "stream").Bool())
|
||||
|
||||
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
|
||||
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
|
||||
require.False(t, gjson.GetBytes(body3, "stream").Bool())
|
||||
}
|
||||
|
||||
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
|
||||
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
||||
// model 缺失
|
||||
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
require.True(t, !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "")
|
||||
|
||||
// model 为数字 → 类型不是 gjson.String,应被拒绝
|
||||
body1b := []byte(`{"model":123,"messages":[{"role":"user","content":"test"}]}`)
|
||||
modelResult1b := gjson.GetBytes(body1b, "model")
|
||||
require.True(t, modelResult1b.Exists())
|
||||
require.NotEqual(t, gjson.String, modelResult1b.Type)
|
||||
|
||||
// messages 缺失
|
||||
body2 := []byte(`{"model":"sora"}`)
|
||||
require.False(t, gjson.GetBytes(body2, "messages").IsArray())
|
||||
|
||||
// messages 不是 JSON 数组(字符串)
|
||||
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
||||
require.False(t, gjson.GetBytes(body3, "messages").IsArray())
|
||||
|
||||
// messages 是对象而非数组 → IsArray 返回 false
|
||||
body4 := []byte(`{"model":"sora","messages":{}}`)
|
||||
require.False(t, gjson.GetBytes(body4, "messages").IsArray())
|
||||
|
||||
// messages 是空数组 → IsArray 为 true 但 len==0,应被拒绝
|
||||
body5 := []byte(`{"model":"sora","messages":[]}`)
|
||||
msgsResult := gjson.GetBytes(body5, "messages")
|
||||
require.True(t, msgsResult.IsArray())
|
||||
require.Equal(t, 0, len(msgsResult.Array()))
|
||||
|
||||
// 非法 JSON 被 gjson.ValidBytes 拦截
|
||||
require.False(t, gjson.ValidBytes([]byte(`{invalid`)))
|
||||
}
|
||||
|
||||
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
||||
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 从 body 提取 prompt_cache_key
|
||||
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||
|
||||
hash := generateOpenAISessionHash(c, body)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// 无 prompt_cache_key 且无 header → 空 hash
|
||||
body2 := []byte(`{"model":"sora"}`)
|
||||
hash2 := generateOpenAISessionHash(c, body2)
|
||||
require.Empty(t, hash2)
|
||||
|
||||
// header 优先于 body
|
||||
c.Request.Header.Set("session_id", "from-header")
|
||||
hash3 := generateOpenAISessionHash(c, body)
|
||||
require.NotEmpty(t, hash3)
|
||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||
}
|
||||
|
||||
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号",
|
||||
errType: "upstream_error",
|
||||
message: `upstream returned "invalid" payload`,
|
||||
},
|
||||
{
|
||||
name: "包含换行和制表符",
|
||||
errType: "rate_limit_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠",
|
||||
errType: "upstream_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||
require.Equal(t, tt.errType, errorObj["type"])
|
||||
require.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare challenge")
|
||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare shield")
|
||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||
}
|
||||
|
||||
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
headers.Set("content-type", "text/html")
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
||||
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
||||
require.Equal(t, "challenge", mitigated)
|
||||
require.Equal(t, "text/html", contentType)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
// Parse additional filters
|
||||
model := c.Query("model")
|
||||
|
||||
var requestType *int16
|
||||
var stream *bool
|
||||
if streamStr := c.Query("stream"); streamStr != "" {
|
||||
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
|
||||
parsed, err := service.ParseUsageRequestType(requestTypeStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
value := int16(parsed)
|
||||
requestType = &value
|
||||
} else if streamStr := c.Query("stream"); streamStr != "" {
|
||||
val, err := strconv.ParseBool(streamStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid stream value, use true or false")
|
||||
@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
APIKeyID: apiKeyID,
|
||||
Model: model,
|
||||
RequestType: requestType,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
StartTime: startTime,
|
||||
@@ -392,7 +403,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
80
backend/internal/handler/usage_handler_request_type_test.go
Normal file
80
backend/internal/handler/usage_handler_request_type_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userUsageRepoCapture struct {
|
||||
service.UsageLogRepository
|
||||
listFilters usagestats.UsageLogFilters
|
||||
}
|
||||
|
||||
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
s.listFilters = filters
|
||||
return []service.UsageLog{}, &pagination.PaginationResult{
|
||||
Total: 0,
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
Pages: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
usageSvc := service.NewUsageService(repo, nil, nil, nil)
|
||||
handler := NewUsageHandler(usageSvc, nil)
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
|
||||
c.Next()
|
||||
})
|
||||
router.GET("/usage", handler.List)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestUserUsageListRequestTypePriority(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, int64(42), repo.listFilters.UserID)
|
||||
require.NotNil(t, repo.listFilters.RequestType)
|
||||
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
|
||||
require.Nil(t, repo.listFilters.Stream)
|
||||
}
|
||||
|
||||
func TestUserUsageListInvalidRequestType(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestUserUsageListInvalidStream(t *testing.T) {
|
||||
repo := &userUsageRepoCapture{}
|
||||
router := newUserUsageRequestTypeTestRouter(repo)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
184
backend/internal/handler/usage_record_submit_task_test.go
Normal file
184
backend/internal/handler/usage_record_submit_task_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool {
|
||||
t.Helper()
|
||||
pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
|
||||
WorkerCount: 1,
|
||||
QueueSize: 8,
|
||||
TaskTimeout: time.Second,
|
||||
OverflowPolicy: "drop",
|
||||
OverflowSamplePercent: 0,
|
||||
AutoScaleEnabled: false,
|
||||
})
|
||||
t.Cleanup(pool.Stop)
|
||||
return pool
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &GatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &GatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &OpenAIGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
|
||||
pool := newUsageRecordTestPool(t)
|
||||
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
|
||||
|
||||
done := make(chan struct{})
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
close(done)
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("task not executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
t.Fatal("expected deadline in fallback context")
|
||||
}
|
||||
called.Store(true)
|
||||
})
|
||||
|
||||
require.True(t, called.Load())
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
|
||||
h := &SoraGatewayHandler{}
|
||||
var called atomic.Bool
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
panic("usage task panic")
|
||||
})
|
||||
})
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
called.Store(true)
|
||||
})
|
||||
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
|
||||
}
|
||||
237
backend/internal/handler/user_msg_queue_helper.go
Normal file
237
backend/internal/handler/user_msg_queue_helper.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助
|
||||
// 复用 ConcurrencyHelper 的退避 + SSE ping 模式
|
||||
type UserMsgQueueHelper struct {
|
||||
queueService *service.UserMessageQueueService
|
||||
pingFormat SSEPingFormat
|
||||
pingInterval time.Duration
|
||||
}
|
||||
|
||||
// NewUserMsgQueueHelper 创建用户消息串行队列辅助
|
||||
func NewUserMsgQueueHelper(
|
||||
queueService *service.UserMessageQueueService,
|
||||
pingFormat SSEPingFormat,
|
||||
pingInterval time.Duration,
|
||||
) *UserMsgQueueHelper {
|
||||
if pingInterval <= 0 {
|
||||
pingInterval = defaultPingInterval
|
||||
}
|
||||
return &UserMsgQueueHelper{
|
||||
queueService: queueService,
|
||||
pingFormat: pingFormat,
|
||||
pingInterval: pingInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping
|
||||
// 返回的 releaseFunc 内部使用 sync.Once,确保只执行一次释放
|
||||
func (h *UserMsgQueueHelper) AcquireWithWait(
|
||||
c *gin.Context,
|
||||
accountID int64,
|
||||
baseRPM int,
|
||||
isStream bool,
|
||||
streamStarted *bool,
|
||||
timeout time.Duration,
|
||||
reqLog *zap.Logger,
|
||||
) (releaseFunc func(), err error) {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 先尝试立即获取
|
||||
result, err := h.queueService.TryAcquire(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err // fail-open 已在 service 层处理
|
||||
}
|
||||
|
||||
if result.Acquired {
|
||||
// 获取成功,执行 RPM 自适应延迟
|
||||
if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
// 延迟期间 context 取消,释放锁
|
||||
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
|
||||
bgCancel()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
|
||||
return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
|
||||
}
|
||||
|
||||
// 需要等待:指数退避轮询
|
||||
return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog)
|
||||
}
|
||||
|
||||
// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping
|
||||
func (h *UserMsgQueueHelper) waitForLockWithPing(
|
||||
c *gin.Context,
|
||||
ctx context.Context,
|
||||
accountID int64,
|
||||
baseRPM int,
|
||||
isStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), error) {
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
var ok bool
|
||||
flusher, ok = c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
needPing = false
|
||||
}
|
||||
}
|
||||
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(h.pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
backoff := initialBackoff
|
||||
timer := time.NewTimer(backoff)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("umq wait timeout for account %d", accountID)
|
||||
|
||||
case <-pingCh:
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
case <-timer.C:
|
||||
result, err := h.queueService.TryAcquire(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Acquired {
|
||||
// 获取成功,执行 RPM 自适应延迟
|
||||
if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil {
|
||||
if ctx.Err() != nil {
|
||||
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = h.queueService.Release(bgCtx, accountID, result.RequestID)
|
||||
bgCancel()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID))
|
||||
return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil
|
||||
}
|
||||
backoff = nextBackoff(backoff)
|
||||
timer.Reset(backoff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次)
|
||||
func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() {
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer bgCancel()
|
||||
if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil {
|
||||
reqLog.Warn("gateway.umq_release_failed",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping
|
||||
// 不获取串行锁,不阻塞并发。返回后即可转发请求。
|
||||
func (h *UserMsgQueueHelper) ThrottleWithPing(
|
||||
c *gin.Context,
|
||||
accountID int64,
|
||||
baseRPM int,
|
||||
isStream bool,
|
||||
streamStarted *bool,
|
||||
timeout time.Duration,
|
||||
reqLog *zap.Logger,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM)
|
||||
if delay <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
reqLog.Debug("gateway.umq_throttle_delay",
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Duration("delay", delay),
|
||||
)
|
||||
|
||||
// 延迟期间发送 SSE ping(复用 waitForLockWithPing 的 ping 逻辑)
|
||||
needPing := isStream && h.pingFormat != ""
|
||||
var flusher http.Flusher
|
||||
if needPing {
|
||||
flusher, _ = c.Writer.(http.Flusher)
|
||||
if flusher == nil {
|
||||
needPing = false
|
||||
}
|
||||
}
|
||||
|
||||
var pingCh <-chan time.Time
|
||||
if needPing {
|
||||
pingTicker := time.NewTicker(h.pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
pingCh = pingTicker.C
|
||||
}
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-pingCh:
|
||||
// SSE ping 逻辑(与 waitForLockWithPing 一致)
|
||||
if !*streamStarted {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
*streamStarted = true
|
||||
}
|
||||
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
||||
return err
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
announcementHandler *admin.AnnouncementHandler,
|
||||
dataManagementHandler *admin.DataManagementHandler,
|
||||
oauthHandler *admin.OAuthHandler,
|
||||
openaiOAuthHandler *admin.OpenAIOAuthHandler,
|
||||
geminiOAuthHandler *admin.GeminiOAuthHandler,
|
||||
@@ -28,6 +29,7 @@ func ProvideAdminHandlers(
|
||||
usageHandler *admin.UsageHandler,
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
apiKeyHandler *admin.AdminAPIKeyHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -35,6 +37,7 @@ func ProvideAdminHandlers(
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
Announcement: announcementHandler,
|
||||
DataManagement: dataManagementHandler,
|
||||
OAuth: oauthHandler,
|
||||
OpenAIOAuth: openaiOAuthHandler,
|
||||
GeminiOAuth: geminiOAuthHandler,
|
||||
@@ -49,12 +52,13 @@ func ProvideAdminHandlers(
|
||||
Usage: usageHandler,
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
APIKey: apiKeyHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService)
|
||||
func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler {
|
||||
return admin.NewSystemHandler(updateService, lockService)
|
||||
}
|
||||
|
||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||
@@ -74,8 +78,12 @@ func ProvideHandlers(
|
||||
adminHandlers *AdminHandlers,
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
soraGatewayHandler *SoraGatewayHandler,
|
||||
soraClientHandler *SoraClientHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
_ *service.IdempotencyCoordinator,
|
||||
_ *service.IdempotencyCleanupService,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -88,6 +96,8 @@ func ProvideHandlers(
|
||||
Admin: adminHandlers,
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
SoraGateway: soraGatewayHandler,
|
||||
SoraClient: soraClientHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
@@ -105,6 +115,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAnnouncementHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewSoraGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
@@ -114,6 +125,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
admin.NewAnnouncementHandler,
|
||||
admin.NewDataManagementHandler,
|
||||
admin.NewOAuthHandler,
|
||||
admin.NewOpenAIOAuthHandler,
|
||||
admin.NewGeminiOAuthHandler,
|
||||
@@ -128,6 +140,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewUsageHandler,
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
admin.NewAdminAPIKeyHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
@@ -21,11 +21,18 @@ var (
|
||||
// - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户)
|
||||
// - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户)
|
||||
endpointPrefix = getEnv("ENDPOINT_PREFIX", "")
|
||||
claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3"
|
||||
geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f"
|
||||
testInterval = 1 * time.Second // 测试间隔,防止限流
|
||||
)
|
||||
|
||||
const (
|
||||
// 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。
|
||||
// 例如:
|
||||
// export CLAUDE_API_KEY="sk-..."
|
||||
// export GEMINI_API_KEY="sk-..."
|
||||
claudeAPIKeyEnv = "CLAUDE_API_KEY"
|
||||
geminiAPIKeyEnv = "GEMINI_API_KEY"
|
||||
)
|
||||
|
||||
func getEnv(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
@@ -65,16 +72,45 @@ func TestMain(m *testing.M) {
|
||||
if endpointPrefix != "" {
|
||||
mode = "Antigravity 模式"
|
||||
}
|
||||
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode)
|
||||
claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != ""
|
||||
geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != ""
|
||||
fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n",
|
||||
baseURL,
|
||||
endpointPrefix,
|
||||
mode,
|
||||
claudeAPIKeyEnv,
|
||||
claudeKeySet,
|
||||
geminiAPIKeyEnv,
|
||||
geminiKeySet,
|
||||
)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func requireClaudeAPIKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
|
||||
if key == "" {
|
||||
t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func requireGeminiAPIKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
|
||||
if key == "" {
|
||||
t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// TestClaudeModelsList 测试 GET /v1/models
|
||||
func TestClaudeModelsList(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
url := baseURL + endpointPrefix + "/v1/models"
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) {
|
||||
|
||||
// TestGeminiModelsList 测试 GET /v1beta/models
|
||||
func TestGeminiModelsList(t *testing.T) {
|
||||
geminiKey := requireGeminiAPIKey(t)
|
||||
url := baseURL + endpointPrefix + "/v1beta/models"
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+geminiKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) {
|
||||
|
||||
// TestClaudeMessages 测试 Claude /v1/messages 接口
|
||||
func TestClaudeMessages(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
for i, model := range claudeModels {
|
||||
if i > 0 {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_非流式", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, false)
|
||||
testClaudeMessage(t, claudeKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_流式", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, true)
|
||||
testClaudeMessage(t, claudeKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeMessage(t *testing.T, model string, stream bool) {
|
||||
func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
payload := map[string]any{
|
||||
@@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) {
|
||||
|
||||
// TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口
|
||||
func TestGeminiGenerateContent(t *testing.T) {
|
||||
geminiKey := requireGeminiAPIKey(t)
|
||||
for i, model := range geminiModels {
|
||||
if i > 0 {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_非流式", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, false)
|
||||
testGeminiGenerate(t, geminiKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_流式", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, true)
|
||||
testGeminiGenerate(t, geminiKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
||||
func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) {
|
||||
action := "generateContent"
|
||||
if stream {
|
||||
action = "streamGenerateContent"
|
||||
@@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+geminiAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+geminiKey)
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
@@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
||||
// TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求
|
||||
// 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段
|
||||
func TestClaudeMessagesWithComplexTools(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
// 测试模型列表(只测试几个代表性模型)
|
||||
models := []string{
|
||||
"claude-opus-4-5-20251101", // Claude 模型
|
||||
@@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_复杂工具", func(t *testing.T) {
|
||||
testClaudeMessageWithTools(t, model)
|
||||
testClaudeMessageWithTools(t, claudeKey, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeMessageWithTools(t *testing.T, model string) {
|
||||
func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
// 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具)
|
||||
@@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) {
|
||||
// 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时,
|
||||
// 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误
|
||||
func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
models := []string{
|
||||
"claude-haiku-4-5-20251001", // gemini-3-flash
|
||||
}
|
||||
@@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_thinking模式工具调用", func(t *testing.T) {
|
||||
testClaudeThinkingWithToolHistory(t, model)
|
||||
testClaudeThinkingWithToolHistory(t, claudeKey, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
|
||||
func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
// 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话
|
||||
@@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||
if endpointPrefix != "/antigravity" {
|
||||
t.Skip("仅在 Antigravity 模式下运行")
|
||||
}
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
|
||||
// 测试通过 Claude 端点调用 Gemini 模型
|
||||
geminiViaClaude := []string{
|
||||
@@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_通过Claude端点", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, false)
|
||||
testClaudeMessage(t, claudeKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_通过Claude端点_流式", func(t *testing.T) {
|
||||
testClaudeMessage(t, model, true)
|
||||
testClaudeMessage(t, claudeKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) {
|
||||
// TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景
|
||||
// 验证:Gemini 模型接受没有 signature 的 thinking block
|
||||
func TestClaudeMessagesWithNoSignature(t *testing.T) {
|
||||
claudeKey := requireClaudeAPIKey(t)
|
||||
models := []string{
|
||||
"claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature
|
||||
}
|
||||
@@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_无signature", func(t *testing.T) {
|
||||
testClaudeWithNoSignature(t, model)
|
||||
testClaudeWithNoSignature(t, claudeKey, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testClaudeWithNoSignature(t *testing.T, model string) {
|
||||
func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) {
|
||||
url := baseURL + endpointPrefix + "/v1/messages"
|
||||
|
||||
// 模拟历史对话包含 thinking block 但没有 signature
|
||||
@@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) {
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+claudeAPIKey)
|
||||
req.Header.Set("Authorization", "Bearer "+claudeKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
@@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
|
||||
if endpointPrefix != "/antigravity" {
|
||||
t.Skip("仅在 Antigravity 模式下运行")
|
||||
}
|
||||
geminiKey := requireGeminiAPIKey(t)
|
||||
|
||||
// 测试通过 Gemini 端点调用 Claude 模型
|
||||
claudeViaGemini := []string{
|
||||
@@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) {
|
||||
time.Sleep(testInterval)
|
||||
}
|
||||
t.Run(model+"_通过Gemini端点", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, false)
|
||||
testGeminiGenerate(t, geminiKey, model, false)
|
||||
})
|
||||
time.Sleep(testInterval)
|
||||
t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) {
|
||||
testGeminiGenerate(t, model, true)
|
||||
testGeminiGenerate(t, geminiKey, model, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
48
backend/internal/integration/e2e_helpers_test.go
Normal file
48
backend/internal/integration/e2e_helpers_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
//go:build e2e
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// E2E Mock 模式支持
|
||||
// =============================================================================
|
||||
// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。
|
||||
// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。
|
||||
|
||||
// isMockMode 检查是否启用 Mock 模式
|
||||
func isMockMode() bool {
|
||||
return strings.EqualFold(os.Getenv("E2E_MOCK"), "true")
|
||||
}
|
||||
|
||||
// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试
|
||||
func skipIfNoRealAPI(t *testing.T) {
|
||||
t.Helper()
|
||||
if isMockMode() {
|
||||
return // Mock 模式下不跳过
|
||||
}
|
||||
claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv))
|
||||
geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv))
|
||||
if claudeKey == "" && geminiKey == "" {
|
||||
t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// API Key 脱敏(Task 6.10)
|
||||
// =============================================================================
|
||||
|
||||
// safeLogKey 安全地记录 API Key(仅显示前 8 位)
|
||||
func safeLogKey(t *testing.T, prefix string, key string) {
|
||||
t.Helper()
|
||||
key = strings.TrimSpace(key)
|
||||
if len(key) <= 8 {
|
||||
t.Logf("%s: ***(长度: %d)", prefix, len(key))
|
||||
return
|
||||
}
|
||||
t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key))
|
||||
}
|
||||
317
backend/internal/integration/e2e_user_flow_test.go
Normal file
317
backend/internal/integration/e2e_user_flow_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
//go:build e2e
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// E2E 用户流程测试
|
||||
// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量
|
||||
|
||||
var (
|
||||
testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local"
|
||||
testUserPassword = "E2eTest@12345"
|
||||
testUserName = "e2e-test-user"
|
||||
)
|
||||
|
||||
// TestUserRegistrationAndLogin 测试用户注册和登录流程
|
||||
func TestUserRegistrationAndLogin(t *testing.T) {
|
||||
// 步骤 1: 注册新用户
|
||||
t.Run("注册新用户", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": testUserEmail,
|
||||
"password": testUserPassword,
|
||||
"username": testUserName,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/register", body, "")
|
||||
if err != nil {
|
||||
t.Skipf("注册接口不可用,跳过用户流程测试: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭)
|
||||
switch resp.StatusCode {
|
||||
case 200:
|
||||
t.Logf("✅ 用户注册成功: %s", testUserEmail)
|
||||
case 400:
|
||||
t.Logf("⚠️ 用户可能已存在: %s", string(respBody))
|
||||
case 403:
|
||||
t.Skipf("注册功能已关闭: %s", string(respBody))
|
||||
default:
|
||||
t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody))
|
||||
}
|
||||
})
|
||||
|
||||
// 步骤 2: 登录获取 JWT
|
||||
var accessToken string
|
||||
t.Run("用户登录获取JWT", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": testUserEmail,
|
||||
"password": testUserPassword,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
|
||||
if err != nil {
|
||||
t.Fatalf("登录请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
t.Fatalf("解析登录响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 尝试从标准响应格式获取 token
|
||||
if token, ok := result["access_token"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
} else if data, ok := result["data"].(map[string]any); ok {
|
||||
if token, ok := data["access_token"].(string); ok {
|
||||
accessToken = token
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken == "" {
|
||||
t.Skipf("未获取到 access_token,响应: %s", string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 token 不为空且格式基本正确
|
||||
if len(accessToken) < 10 {
|
||||
t.Fatalf("access_token 格式异常: %s", accessToken)
|
||||
}
|
||||
|
||||
t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken))
|
||||
})
|
||||
|
||||
if accessToken == "" {
|
||||
t.Skip("未获取到 JWT,跳过后续测试")
|
||||
return
|
||||
}
|
||||
|
||||
// 步骤 3: 使用 JWT 获取当前用户信息
|
||||
t.Run("获取当前用户信息", func(t *testing.T) {
|
||||
resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
t.Logf("✅ 成功获取用户信息")
|
||||
})
|
||||
}
|
||||
|
||||
// TestAPIKeyLifecycle 测试 API Key 的创建和使用
|
||||
func TestAPIKeyLifecycle(t *testing.T) {
|
||||
// 先登录获取 JWT
|
||||
accessToken := loginTestUser(t)
|
||||
if accessToken == "" {
|
||||
t.Skip("无法登录,跳过 API Key 生命周期测试")
|
||||
return
|
||||
}
|
||||
|
||||
var apiKey string
|
||||
|
||||
// 步骤 1: 创建 API Key
|
||||
t.Run("创建API_Key", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()),
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/keys", body, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("创建 API Key 请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 从响应中提取 key
|
||||
if key, ok := result["key"].(string); ok {
|
||||
apiKey = key
|
||||
} else if data, ok := result["data"].(map[string]any); ok {
|
||||
if key, ok := data["key"].(string); ok {
|
||||
apiKey = key
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey == "" {
|
||||
t.Skipf("未获取到 API Key,响应: %s", string(respBody))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 API Key 脱敏日志(只显示前 8 位)
|
||||
masked := apiKey
|
||||
if len(masked) > 8 {
|
||||
masked = masked[:8] + "..."
|
||||
}
|
||||
t.Logf("✅ API Key 创建成功: %s", masked)
|
||||
})
|
||||
|
||||
if apiKey == "" {
|
||||
t.Skip("未创建 API Key,跳过后续测试")
|
||||
return
|
||||
}
|
||||
|
||||
// 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用)
|
||||
t.Run("使用API_Key调用网关", func(t *testing.T) {
|
||||
// 尝试调用 models 列表(最轻量的 API 调用)
|
||||
resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey)
|
||||
if err != nil {
|
||||
t.Fatalf("网关请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 可能返回 200(成功)或 402(余额不足)或 403(无可用账户)
|
||||
switch {
|
||||
case resp.StatusCode == 200:
|
||||
t.Logf("✅ API Key 网关调用成功")
|
||||
case resp.StatusCode == 402:
|
||||
t.Logf("⚠️ 余额不足,但 API Key 认证通过")
|
||||
case resp.StatusCode == 403:
|
||||
t.Logf("⚠️ 无可用账户,但 API Key 认证通过")
|
||||
default:
|
||||
t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
})
|
||||
|
||||
// 步骤 3: 查询用量记录
|
||||
t.Run("查询用量记录", func(t *testing.T) {
|
||||
resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("用量查询请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("✅ 用量查询成功")
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 辅助函数
|
||||
// =============================================================================
|
||||
|
||||
func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) {
|
||||
t.Helper()
|
||||
|
||||
url := baseURL + path
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func loginTestUser(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
// 先尝试用管理员账户登录
|
||||
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local")
|
||||
adminPassword := getEnv("ADMIN_PASSWORD", "")
|
||||
|
||||
if adminPassword == "" {
|
||||
// 尝试用测试用户
|
||||
adminEmail = testUserEmail
|
||||
adminPassword = testUserPassword
|
||||
}
|
||||
|
||||
payload := map[string]string{
|
||||
"email": adminEmail,
|
||||
"password": adminPassword,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
resp, err := doRequest(t, "POST", "/api/auth/login", body, "")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return ""
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if token, ok := result["access_token"].(string); ok {
|
||||
return token
|
||||
}
|
||||
if data, ok := result["data"].(map[string]any); ok {
|
||||
if token, ok := data["access_token"].(string); ok {
|
||||
return token
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// redactAPIKey API Key 脱敏,只显示前 8 位
|
||||
func redactAPIKey(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if len(key) <= 8 {
|
||||
return "***"
|
||||
}
|
||||
return key[:8] + "..."
|
||||
}
|
||||
@@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) {
|
||||
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPsIndependent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
callCounts := make(map[string]int64)
|
||||
originalRun := rateLimitRun
|
||||
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
|
||||
callCounts[key]++
|
||||
return callCounts[key], false, nil
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
rateLimitRun = originalRun
|
||||
})
|
||||
|
||||
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
|
||||
|
||||
router := gin.New()
|
||||
router.Use(limiter.Limit("api", 1, time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
// 第一个 IP 的请求应通过
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req1.RemoteAddr = "10.0.0.1:1234"
|
||||
rec1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec1, req1)
|
||||
require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过")
|
||||
|
||||
// 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req2.RemoteAddr = "10.0.0.2:5678"
|
||||
rec2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec2, req2)
|
||||
require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过")
|
||||
|
||||
// 第一个 IP 的第二次请求应被限流
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req3.RemoteAddr = "10.0.0.1:1234"
|
||||
rec3 := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec3, req3)
|
||||
require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流")
|
||||
}
|
||||
|
||||
func TestRateLimiterSuccessAndLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -151,6 +151,9 @@ var claudeModels = []modelDef{
|
||||
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||
}
|
||||
|
||||
// Antigravity 支持的 Gemini 模型
|
||||
@@ -161,6 +164,10 @@ var geminiModels = []modelDef{
|
||||
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byID := make(map[string]ClaudeModel, len(models))
|
||||
for _, m := range models {
|
||||
byID[m.ID] = m
|
||||
}
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
}
|
||||
|
||||
for _, id := range requiredIDs {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,9 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
)
|
||||
|
||||
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
|
||||
@@ -33,7 +36,7 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri
|
||||
// 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("User-Agent", GetUserAgent())
|
||||
|
||||
return req, nil
|
||||
}
|
||||
@@ -149,22 +152,26 @@ type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(proxyURL string) *Client {
|
||||
func NewClient(proxyURL string) (*Client, error) {
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURLParsed),
|
||||
}
|
||||
_, parsed, err := proxyurl.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed != nil {
|
||||
transport := &http.Transport{}
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
|
||||
return nil, fmt.Errorf("configure proxy: %w", err)
|
||||
}
|
||||
client.Transport = transport
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: client,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
|
||||
@@ -204,9 +211,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
@@ -243,9 +255,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
|
||||
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
@@ -333,7 +350,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("User-Agent", GetUserAgent())
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -412,7 +429,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("User-Agent", GetUserAgent())
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -532,7 +549,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", UserAgent)
|
||||
req.Header.Set("User-Agent", GetUserAgent())
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
|
||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
||||
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
|
||||
type GeminiImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||
|
||||
@@ -6,10 +6,14 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,8 +23,10 @@ const (
|
||||
UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
|
||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
@@ -32,9 +38,6 @@ const (
|
||||
"https://www.googleapis.com/auth/cclog " +
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs"
|
||||
|
||||
// User-Agent(与 Antigravity-Manager 保持一致)
|
||||
UserAgent = "antigravity/1.15.8 windows/amd64"
|
||||
|
||||
// Session 过期时间
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
@@ -46,6 +49,35 @@ const (
|
||||
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
)
|
||||
|
||||
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" {
|
||||
defaultUserAgentVersion = version
|
||||
}
|
||||
// 从环境变量读取 client_secret,未设置则使用默认值
|
||||
if secret := os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" {
|
||||
defaultClientSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
func GetUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||
}
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if v := strings.TrimSpace(defaultClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
|
||||
}
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||
var BaseURLs = []string{
|
||||
antigravityProdBaseURL, // prod (优先)
|
||||
|
||||
718
backend/internal/pkg/antigravity/oauth_test.go
Normal file
718
backend/internal/pkg/antigravity/oauth_test.go
Normal file
@@ -0,0 +1,718 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getClientSecret
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||
|
||||
// 需要重新触发 init 逻辑:手动从环境变量读取
|
||||
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "my-secret-value" {
|
||||
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = " "
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = " valid-secret "
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "valid-secret" {
|
||||
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ForwardBaseURLs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestForwardBaseURLs_Daily优先(t *testing.T) {
|
||||
urls := ForwardBaseURLs()
|
||||
if len(urls) == 0 {
|
||||
t.Fatal("ForwardBaseURLs 返回空列表")
|
||||
}
|
||||
|
||||
// daily URL 应排在第一位
|
||||
if urls[0] != antigravityDailyBaseURL {
|
||||
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
|
||||
}
|
||||
|
||||
// 应包含所有 URL
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
|
||||
// 验证 prod URL 也在列表中
|
||||
found := false
|
||||
for _, u := range urls {
|
||||
if u == antigravityProdBaseURL {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("ForwardBaseURLs 中缺少 prod URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
|
||||
originalFirst := BaseURLs[0]
|
||||
_ = ForwardBaseURLs()
|
||||
// 确保原始 BaseURLs 未被修改
|
||||
if BaseURLs[0] != originalFirst {
|
||||
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URLAvailability
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewURLAvailability(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if ua == nil {
|
||||
t.Fatal("NewURLAvailability 返回 nil")
|
||||
}
|
||||
if ua.ttl != 5*time.Minute {
|
||||
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
|
||||
}
|
||||
if ua.unavailable == nil {
|
||||
t.Error("unavailable map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkUnavailable(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后 IsAvailable 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkSuccess(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
// 先标记为不可用
|
||||
ua.MarkUnavailable(testURL)
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后应不可用")
|
||||
}
|
||||
|
||||
// 标记成功后应恢复可用
|
||||
ua.MarkSuccess(testURL)
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("MarkSuccess 后应恢复可用")
|
||||
}
|
||||
|
||||
// 验证 lastSuccess 被设置
|
||||
ua.mu.RLock()
|
||||
if ua.lastSuccess != testURL {
|
||||
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
|
||||
}
|
||||
ua.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
|
||||
// 使用极短的 TTL
|
||||
ua := NewURLAvailability(1 * time.Millisecond)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("TTL 过期后 URL 应恢复可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if !ua.IsAvailable("https://never-marked.com") {
|
||||
t.Error("未标记的 URL 应默认可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
// 默认所有 URL 都可用
|
||||
urls := ua.GetAvailableURLs()
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
if len(BaseURLs) < 2 {
|
||||
t.Skip("BaseURLs 少于 2 个,跳过此测试")
|
||||
}
|
||||
|
||||
ua.MarkUnavailable(BaseURLs[0])
|
||||
urls := ua.GetAvailableURLs()
|
||||
|
||||
// 标记的 URL 不应出现在可用列表中
|
||||
for _, u := range urls {
|
||||
if u == BaseURLs[0] {
|
||||
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
ua.MarkSuccess("https://c.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
// c.com 应排在第一位
|
||||
if urls[0] != "https://c.com" {
|
||||
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
|
||||
}
|
||||
// 其余按原始顺序
|
||||
if urls[1] != "https://a.com" {
|
||||
t.Errorf("第二个应为 a.com: got %s", urls[1])
|
||||
}
|
||||
if urls[2] != "https://b.com" {
|
||||
t.Errorf("第三个应为 b.com: got %s", urls[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://b.com")
|
||||
ua.MarkUnavailable("https://b.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// b.com 被标记不可用,不应出现
|
||||
if len(urls) != 1 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
|
||||
}
|
||||
if urls[0] != "https://a.com" {
|
||||
t.Errorf("仅 a.com 应可用: got %s", urls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://not-in-list.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// lastSuccess 不在自定义列表中,不应被添加
|
||||
if len(urls) != 2 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewSessionStore(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("NewSessionStore 返回 nil")
|
||||
}
|
||||
if store.sessions == nil {
|
||||
t.Error("sessions map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
CodeVerifier: "test-verifier",
|
||||
ProxyURL: "http://proxy.example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("session-1", session)
|
||||
|
||||
got, ok := store.Get("session-1")
|
||||
if !ok {
|
||||
t.Fatal("Get 应返回 true")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("State 不匹配: got %s", got.State)
|
||||
}
|
||||
if got.CodeVerifier != "test-verifier" {
|
||||
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
|
||||
}
|
||||
if got.ProxyURL != "http://proxy.example.com" {
|
||||
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("不存在的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_过期(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
|
||||
}
|
||||
|
||||
store.Set("expired-session", session)
|
||||
|
||||
_, ok := store.Get("expired-session")
|
||||
if ok {
|
||||
t.Error("过期的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("del-session", session)
|
||||
store.Delete("del-session")
|
||||
|
||||
_, ok := store.Get("del-session")
|
||||
if ok {
|
||||
t.Error("删除后 Get 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 删除不存在的 session 不应 panic
|
||||
store.Delete("nonexistent")
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Stop()
|
||||
|
||||
// 多次 Stop 不应 panic
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_多个Session(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
session := &OAuthSession{
|
||||
State: "state-" + string(rune('0'+i)),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("session-"+string(rune('0'+i)), session)
|
||||
}
|
||||
|
||||
// 验证都能取到
|
||||
for i := 0; i < 10; i++ {
|
||||
_, ok := store.Get("session-" + string(rune('0'+i)))
|
||||
if !ok {
|
||||
t.Errorf("session-%d 应存在", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes_长度正确(t *testing.T) {
|
||||
sizes := []int{0, 1, 16, 32, 64, 128}
|
||||
for _, size := range sizes {
|
||||
b, err := GenerateRandomBytes(size)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
|
||||
}
|
||||
if len(b) != size {
|
||||
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
|
||||
b1, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第一次调用失败: %v", err)
|
||||
}
|
||||
b2, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第二次调用失败: %v", err)
|
||||
}
|
||||
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
|
||||
if string(b1) == string(b2) {
|
||||
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState_返回值格式(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState 失败: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(state, "+/=") {
|
||||
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
|
||||
if len(state) != 43 {
|
||||
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState_唯一性(t *testing.T) {
|
||||
s1, _ := GenerateState()
|
||||
s2, _ := GenerateState()
|
||||
if s1 == s2 {
|
||||
t.Error("两次 GenerateState 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID_返回值格式(t *testing.T) {
|
||||
id, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID 失败: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Error("GenerateSessionID 返回空字符串")
|
||||
}
|
||||
// 16 字节的 hex 编码长度应为 32
|
||||
if len(id) != 32 {
|
||||
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
|
||||
}
|
||||
// 验证是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(id); err != nil {
|
||||
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_唯一性(t *testing.T) {
|
||||
id1, _ := GenerateSessionID()
|
||||
id2, _ := GenerateSessionID()
|
||||
if id1 == id2 {
|
||||
t.Error("两次 GenerateSessionID 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(verifier, "+/=") {
|
||||
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43
|
||||
if len(verifier) != 43 {
|
||||
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
|
||||
v1, _ := GenerateCodeVerifier()
|
||||
v2, _ := GenerateCodeVerifier()
|
||||
if v1 == v2 {
|
||||
t.Error("两次 GenerateCodeVerifier 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
|
||||
// 手动计算预期值
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
|
||||
|
||||
if challenge != expected {
|
||||
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("another-verifier")
|
||||
if strings.ContainsAny(challenge, "+/") {
|
||||
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("same-verifier")
|
||||
c2 := GenerateCodeChallenge("same-verifier")
|
||||
if c1 != c2 {
|
||||
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("verifier-1")
|
||||
c2 := GenerateCodeChallenge("verifier-2")
|
||||
if c1 == c2 {
|
||||
t.Error("不同输入应产生不同输出")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
codeChallenge := "test-challenge-abc"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
// 验证以 AuthorizeURL 开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
|
||||
}
|
||||
|
||||
// 解析 URL 并验证参数
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
for key, want := range expectedParams {
|
||||
got := params.Get(key)
|
||||
if got != want {
|
||||
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
|
||||
authURL := BuildAuthorizationURL("s", "c")
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
// 应包含 10 个参数
|
||||
expectedCount := 10
|
||||
if len(params) != expectedCount {
|
||||
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
|
||||
state := "state+with/special=chars"
|
||||
codeChallenge := "challenge+value"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析后应正确还原特殊字符
|
||||
if got := parsed.Query().Get("state"); got != state {
|
||||
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 常量值验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestConstants_值正确(t *testing.T) {
|
||||
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
|
||||
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
|
||||
}
|
||||
if TokenURL != "https://oauth2.googleapis.com/token" {
|
||||
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
|
||||
}
|
||||
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
|
||||
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
|
||||
}
|
||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||
}
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||
}
|
||||
if secret != "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" {
|
||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
|
||||
}
|
||||
if URLAvailabilityTTL != 5*time.Minute {
|
||||
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopes_包含必要范围(t *testing.T) {
|
||||
expectedScopes := []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
for _, scope := range expectedScopes {
|
||||
if !strings.Contains(Scopes, scope) {
|
||||
t.Errorf("Scopes 缺少 %s", scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -206,6 +206,7 @@ type modelInfo struct {
|
||||
var modelInfoMap = map[string]modelInfo{
|
||||
"claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"},
|
||||
"claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"},
|
||||
"claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"},
|
||||
"claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"},
|
||||
"claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"},
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
|
||||
var fallbackCounter uint64
|
||||
|
||||
// generateRandomID 生成密码学安全的随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
id := make([]byte, 12)
|
||||
randBytes := make([]byte, 12)
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
// 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。
|
||||
// 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。
|
||||
cnt := atomic.AddUint64(&fallbackCounter, 1)
|
||||
seed := uint64(time.Now().UnixNano()) ^ cnt
|
||||
seed ^= uint64(len(err.Error())) << 32
|
||||
for i := range id {
|
||||
seed ^= seed << 13
|
||||
seed ^= seed >> 7
|
||||
seed ^= seed << 17
|
||||
id[i] = chars[int(seed)%len(chars)]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
return string(result)
|
||||
for i, b := range randBytes {
|
||||
id[i] = chars[int(b)%len(chars)]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
|
||||
109
backend/internal/pkg/antigravity/response_transformer_test.go
Normal file
109
backend/internal/pkg/antigravity/response_transformer_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
|
||||
|
||||
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
require.Len(t, id, 12, "ID 长度应为 12")
|
||||
_, dup := seen[id]
|
||||
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackCounter_Increments(t *testing.T) {
|
||||
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
|
||||
before := atomic.LoadUint64(&fallbackCounter)
|
||||
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
|
||||
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
|
||||
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
|
||||
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
|
||||
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
|
||||
}
|
||||
|
||||
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
|
||||
// 验证并发递增的原子性 — 每次递增都应产生唯一值
|
||||
const goroutines = 50
|
||||
results := make([]uint64, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// 所有结果应唯一
|
||||
seen := make(map[uint64]bool, goroutines)
|
||||
for _, v := range results {
|
||||
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
|
||||
seen[v] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Charset(t *testing.T) {
|
||||
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
validSet := make(map[byte]struct{}, len(validChars))
|
||||
for i := 0; i < len(validChars); i++ {
|
||||
validSet[validChars[i]] = struct{}{}
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := generateRandomID()
|
||||
for j := 0; j < len(id); j++ {
|
||||
_, ok := validSet[id[j]]
|
||||
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Length(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
|
||||
// 验证并发调用不会产生重复 ID
|
||||
const goroutines = 100
|
||||
results := make([]string, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = generateRandomID()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
seen := make(map[string]bool, goroutines)
|
||||
for _, id := range results {
|
||||
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGenerateRandomID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = generateRandomID()
|
||||
}
|
||||
}
|
||||
@@ -10,8 +10,14 @@ const (
|
||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||
BetaTokenCounting = "token-counting-2024-11-01"
|
||||
BetaContext1M = "context-1m-2025-08-07"
|
||||
BetaFastMode = "fast-mode-2026-02-01"
|
||||
)
|
||||
|
||||
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
|
||||
// 这些 token 是客户端特有的,不应透传给上游 API。
|
||||
var DroppedBetas = []string{BetaContext1M, BetaFastMode}
|
||||
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
@@ -77,6 +83,12 @@ var DefaultModels = []Model{
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
CreatedAt: "2026-02-06T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Sonnet 4.6",
|
||||
CreatedAt: "2026-02-18T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user