mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-04-19 04:47:26 +00:00
Merge branch 'Wei-Shaw:main' into main
This commit is contained in:
@@ -42,6 +42,9 @@ type AdminService interface {
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
|
||||
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
|
||||
|
||||
// API Key management (admin)
|
||||
@@ -57,6 +60,8 @@ type AdminService interface {
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号 privacy_mode,未设置则尝试关闭训练数据共享并持久化。
|
||||
EnsureOpenAIPrivacy(ctx context.Context, account *Account) string
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
|
||||
@@ -433,6 +438,7 @@ type adminServiceImpl struct {
|
||||
settingService *SettingService
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userSubRepo UserSubscriptionRepository
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
@@ -461,6 +467,7 @@ func NewAdminService(
|
||||
settingService *SettingService,
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@@ -479,6 +486,7 @@ func NewAdminService(
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userSubRepo: userSubRepo,
|
||||
privacyClientFactory: privacyClientFactory,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1244,6 +1252,27 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return keys, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.userGroupRateRepo.GetByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil
|
||||
}
|
||||
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return s.groupRepo.UpdateSortOrders(ctx, updates)
|
||||
}
|
||||
@@ -2502,3 +2531,39 @@ func (e *MixedChannelError) Error() string {
|
||||
func (s *adminServiceImpl) ResetAccountQuota(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.ResetQuotaUsed(ctx, id)
|
||||
}
|
||||
|
||||
// EnsureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode,
|
||||
// 未设置则调用 disableOpenAITraining 并持久化到 Extra,返回设置的 mode 值。
|
||||
func (s *adminServiceImpl) EnsureOpenAIPrivacy(ctx context.Context, account *Account) string {
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return ""
|
||||
}
|
||||
if s.privacyClientFactory == nil {
|
||||
return ""
|
||||
}
|
||||
if account.Extra != nil {
|
||||
if _, ok := account.Extra["privacy_mode"]; ok {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
token, _ := account.Credentials["access_token"].(string)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
|
||||
proxyURL = p.URL()
|
||||
}
|
||||
}
|
||||
|
||||
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
|
||||
if mode == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode})
|
||||
return mode
|
||||
}
|
||||
|
||||
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
176
backend/internal/service/admin_service_group_rate_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
|
||||
type userGroupRateRepoStubForGroupRate struct {
|
||||
getByGroupIDData map[int64][]UserGroupRateEntry
|
||||
getByGroupIDErr error
|
||||
|
||||
deletedGroupIDs []int64
|
||||
deleteByGroupErr error
|
||||
|
||||
syncedGroupID int64
|
||||
syncedEntries []GroupRateMultiplierInput
|
||||
syncGroupErr error
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||
panic("unexpected GetByUserID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.getByGroupIDErr != nil {
|
||||
return nil, s.getByGroupIDErr
|
||||
}
|
||||
return s.getByGroupIDData[groupID], nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||
s.syncedGroupID = groupID
|
||||
s.syncedEntries = entries
|
||||
return s.syncGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||
return s.deleteByGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByUserID call")
|
||||
}
|
||||
|
||||
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("returns entries for group", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||
10: {
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, int64(1), entries[0].UserID)
|
||||
require.Equal(t, "alice", entries[0].UserName)
|
||||
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||
require.Equal(t, int64(2), entries[1].UserID)
|
||||
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entries)
|
||||
})
|
||||
|
||||
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{},
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entries)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDErr: errors.New("db error"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "db error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("deletes by group ID", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
deleteByGroupErr: errors.New("delete failed"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "delete failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries := []GroupRateMultiplierInput{
|
||||
{UserID: 1, RateMultiplier: 1.5},
|
||||
{UserID: 2, RateMultiplier: 0.8},
|
||||
}
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), repo.syncedGroupID)
|
||||
require.Equal(t, entries, repo.syncedEntries)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
syncGroupErr: errors.New("sync failed"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
|
||||
{UserID: 1, RateMultiplier: 1.0},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "sync failed")
|
||||
})
|
||||
}
|
||||
@@ -68,7 +68,15 @@ func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
|
||||
panic("unexpected GetByGroupID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
|
||||
panic("unexpected SyncGroupRateMultipliers call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
|
||||
77
backend/internal/service/openai_privacy_service.go
Normal file
77
backend/internal/service/openai_privacy_service.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// PrivacyClientFactory creates an HTTP client for privacy API calls.
|
||||
// Injected from repository layer to avoid import cycles.
|
||||
type PrivacyClientFactory func(proxyURL string) (*req.Client, error)
|
||||
|
||||
const (
|
||||
openAISettingsURL = "https://chatgpt.com/backend-api/settings/account_user_setting"
|
||||
|
||||
PrivacyModeTrainingOff = "training_off"
|
||||
PrivacyModeFailed = "training_set_failed"
|
||||
PrivacyModeCFBlocked = "training_set_cf_blocked"
|
||||
)
|
||||
|
||||
// disableOpenAITraining calls ChatGPT settings API to turn off "Improve the model for everyone".
|
||||
// Returns privacy_mode value: "training_off" on success, "cf_blocked" / "failed" on failure.
|
||||
func disableOpenAITraining(ctx context.Context, clientFactory PrivacyClientFactory, accessToken, proxyURL string) string {
|
||||
if accessToken == "" || clientFactory == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client, err := clientFactory(proxyURL)
|
||||
if err != nil {
|
||||
slog.Warn("openai_privacy_client_error", "error", err.Error())
|
||||
return PrivacyModeFailed
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Origin", "https://chatgpt.com").
|
||||
SetHeader("Referer", "https://chatgpt.com/").
|
||||
SetQueryParam("feature", "training_allowed").
|
||||
SetQueryParam("value", "false").
|
||||
Patch(openAISettingsURL)
|
||||
|
||||
if err != nil {
|
||||
slog.Warn("openai_privacy_request_error", "error", err.Error())
|
||||
return PrivacyModeFailed
|
||||
}
|
||||
|
||||
if resp.StatusCode == 403 || resp.StatusCode == 503 {
|
||||
body := resp.String()
|
||||
if strings.Contains(body, "cloudflare") || strings.Contains(body, "cf-") || strings.Contains(body, "Just a moment") {
|
||||
slog.Warn("openai_privacy_cf_blocked", "status", resp.StatusCode)
|
||||
return PrivacyModeCFBlocked
|
||||
}
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
slog.Warn("openai_privacy_failed", "status", resp.StatusCode, "body", truncate(resp.String(), 200))
|
||||
return PrivacyModeFailed
|
||||
}
|
||||
|
||||
slog.Info("openai_privacy_training_disabled")
|
||||
return PrivacyModeTrainingOff
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + fmt.Sprintf("...(%d more)", len(s)-n)
|
||||
}
|
||||
@@ -11,17 +11,19 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage,
|
||||
// resetQuotaUserSubRepoStub 支持 GetByID、ResetDailyUsage、ResetWeeklyUsage、ResetMonthlyUsage,
|
||||
// 其余方法继承 userSubRepoNoop(panic)。
|
||||
type resetQuotaUserSubRepoStub struct {
|
||||
userSubRepoNoop
|
||||
|
||||
sub *UserSubscription
|
||||
|
||||
resetDailyCalled bool
|
||||
resetWeeklyCalled bool
|
||||
resetDailyErr error
|
||||
resetWeeklyErr error
|
||||
resetDailyCalled bool
|
||||
resetWeeklyCalled bool
|
||||
resetMonthlyCalled bool
|
||||
resetDailyErr error
|
||||
resetWeeklyErr error
|
||||
resetMonthlyErr error
|
||||
}
|
||||
|
||||
func (r *resetQuotaUserSubRepoStub) GetByID(_ context.Context, id int64) (*UserSubscription, error) {
|
||||
@@ -46,6 +48,11 @@ func (r *resetQuotaUserSubRepoStub) ResetWeeklyUsage(_ context.Context, _ int64,
|
||||
return r.resetWeeklyErr
|
||||
}
|
||||
|
||||
func (r *resetQuotaUserSubRepoStub) ResetMonthlyUsage(_ context.Context, _ int64, _ time.Time) error {
|
||||
r.resetMonthlyCalled = true
|
||||
return r.resetMonthlyErr
|
||||
}
|
||||
|
||||
func newResetQuotaSvc(stub *resetQuotaUserSubRepoStub) *SubscriptionService {
|
||||
return NewSubscriptionService(groupRepoNoop{}, stub, nil, nil, nil)
|
||||
}
|
||||
@@ -56,12 +63,13 @@ func TestAdminResetQuota_ResetBoth(t *testing.T) {
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
result, err := svc.AdminResetQuota(context.Background(), 1, true, true)
|
||||
result, err := svc.AdminResetQuota(context.Background(), 1, true, true, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
|
||||
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
|
||||
require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage")
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_ResetDailyOnly(t *testing.T) {
|
||||
@@ -70,12 +78,13 @@ func TestAdminResetQuota_ResetDailyOnly(t *testing.T) {
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
result, err := svc.AdminResetQuota(context.Background(), 2, true, false)
|
||||
result, err := svc.AdminResetQuota(context.Background(), 2, true, false, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.True(t, stub.resetDailyCalled, "应调用 ResetDailyUsage")
|
||||
require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage")
|
||||
require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage")
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) {
|
||||
@@ -84,12 +93,13 @@ func TestAdminResetQuota_ResetWeeklyOnly(t *testing.T) {
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
result, err := svc.AdminResetQuota(context.Background(), 3, false, true)
|
||||
result, err := svc.AdminResetQuota(context.Background(), 3, false, true, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage")
|
||||
require.True(t, stub.resetWeeklyCalled, "应调用 ResetWeeklyUsage")
|
||||
require.False(t, stub.resetMonthlyCalled, "不应调用 ResetMonthlyUsage")
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) {
|
||||
@@ -98,22 +108,24 @@ func TestAdminResetQuota_BothFalseReturnsError(t *testing.T) {
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
_, err := svc.AdminResetQuota(context.Background(), 7, false, false)
|
||||
_, err := svc.AdminResetQuota(context.Background(), 7, false, false, false)
|
||||
|
||||
require.ErrorIs(t, err, ErrInvalidInput)
|
||||
require.False(t, stub.resetDailyCalled)
|
||||
require.False(t, stub.resetWeeklyCalled)
|
||||
require.False(t, stub.resetMonthlyCalled)
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_SubscriptionNotFound(t *testing.T) {
|
||||
stub := &resetQuotaUserSubRepoStub{sub: nil}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
_, err := svc.AdminResetQuota(context.Background(), 999, true, true)
|
||||
_, err := svc.AdminResetQuota(context.Background(), 999, true, true, true)
|
||||
|
||||
require.ErrorIs(t, err, ErrSubscriptionNotFound)
|
||||
require.False(t, stub.resetDailyCalled)
|
||||
require.False(t, stub.resetWeeklyCalled)
|
||||
require.False(t, stub.resetMonthlyCalled)
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) {
|
||||
@@ -124,7 +136,7 @@ func TestAdminResetQuota_ResetDailyUsageError(t *testing.T) {
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
_, err := svc.AdminResetQuota(context.Background(), 4, true, true)
|
||||
_, err := svc.AdminResetQuota(context.Background(), 4, true, true, false)
|
||||
|
||||
require.ErrorIs(t, err, dbErr)
|
||||
require.True(t, stub.resetDailyCalled)
|
||||
@@ -139,12 +151,41 @@ func TestAdminResetQuota_ResetWeeklyUsageError(t *testing.T) {
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
_, err := svc.AdminResetQuota(context.Background(), 5, false, true)
|
||||
_, err := svc.AdminResetQuota(context.Background(), 5, false, true, false)
|
||||
|
||||
require.ErrorIs(t, err, dbErr)
|
||||
require.True(t, stub.resetWeeklyCalled)
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_ResetMonthlyOnly(t *testing.T) {
|
||||
stub := &resetQuotaUserSubRepoStub{
|
||||
sub: &UserSubscription{ID: 8, UserID: 10, GroupID: 20},
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
result, err := svc.AdminResetQuota(context.Background(), 8, false, false, true)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.False(t, stub.resetDailyCalled, "不应调用 ResetDailyUsage")
|
||||
require.False(t, stub.resetWeeklyCalled, "不应调用 ResetWeeklyUsage")
|
||||
require.True(t, stub.resetMonthlyCalled, "应调用 ResetMonthlyUsage")
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_ResetMonthlyUsageError(t *testing.T) {
|
||||
dbErr := errors.New("db error")
|
||||
stub := &resetQuotaUserSubRepoStub{
|
||||
sub: &UserSubscription{ID: 9, UserID: 10, GroupID: 20},
|
||||
resetMonthlyErr: dbErr,
|
||||
}
|
||||
svc := newResetQuotaSvc(stub)
|
||||
|
||||
_, err := svc.AdminResetQuota(context.Background(), 9, false, false, true)
|
||||
|
||||
require.ErrorIs(t, err, dbErr)
|
||||
require.True(t, stub.resetMonthlyCalled)
|
||||
}
|
||||
|
||||
func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) {
|
||||
stub := &resetQuotaUserSubRepoStub{
|
||||
sub: &UserSubscription{
|
||||
@@ -156,7 +197,7 @@ func TestAdminResetQuota_ReturnsRefreshedSub(t *testing.T) {
|
||||
}
|
||||
|
||||
svc := newResetQuotaSvc(stub)
|
||||
result, err := svc.AdminResetQuota(context.Background(), 6, true, false)
|
||||
result, err := svc.AdminResetQuota(context.Background(), 6, true, false, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
// ResetDailyUsage stub 会将 sub.DailyUsageUSD 归零,
|
||||
|
||||
@@ -31,7 +31,7 @@ var (
|
||||
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
|
||||
ErrSubscriptionAssignConflict = infraerrors.Conflict("SUBSCRIPTION_ASSIGN_CONFLICT", "subscription exists but request conflicts with existing assignment semantics")
|
||||
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
|
||||
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily or resetWeekly must be true")
|
||||
ErrInvalidInput = infraerrors.BadRequest("INVALID_INPUT", "at least one of resetDaily, resetWeekly, or resetMonthly must be true")
|
||||
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||
@@ -696,10 +696,10 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *U
|
||||
return s.userSubRepo.ActivateWindows(ctx, sub.ID, windowStart)
|
||||
}
|
||||
|
||||
// AdminResetQuota manually resets the daily and/or weekly usage windows.
|
||||
// AdminResetQuota manually resets the daily, weekly, and/or monthly usage windows.
|
||||
// Uses startOfDay(now) as the new window start, matching automatic resets.
|
||||
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly bool) (*UserSubscription, error) {
|
||||
if !resetDaily && !resetWeekly {
|
||||
func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionID int64, resetDaily, resetWeekly, resetMonthly bool) (*UserSubscription, error) {
|
||||
if !resetDaily && !resetWeekly && !resetMonthly {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
@@ -717,8 +717,18 @@ func (s *SubscriptionService) AdminResetQuota(ctx context.Context, subscriptionI
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Invalidate caches, same as CheckAndResetWindows
|
||||
if resetMonthly {
|
||||
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, windowStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Invalidate L1 ristretto cache. Ristretto's Del() is asynchronous by design,
|
||||
// so call Wait() immediately after to flush pending operations and guarantee
|
||||
// the deleted key is not returned on the very next Get() call.
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.subCacheL1 != nil {
|
||||
s.subCacheL1.Wait()
|
||||
}
|
||||
if s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,10 @@ type TokenRefreshService struct {
|
||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||
|
||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
proxyRepo ProxyRepository
|
||||
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
@@ -72,6 +76,12 @@ func (s *TokenRefreshService) SetSoraAccountRepo(repo SoraAccountRepository) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetPrivacyDeps 注入 OpenAI privacy opt-out 所需依赖
|
||||
func (s *TokenRefreshService) SetPrivacyDeps(factory PrivacyClientFactory, proxyRepo ProxyRepository) {
|
||||
s.privacyClientFactory = factory
|
||||
s.proxyRepo = proxyRepo
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
@@ -277,6 +287,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
slog.Debug("token_refresh.scheduler_cache_synced", "account_id", account.ID)
|
||||
}
|
||||
}
|
||||
// OpenAI OAuth: 刷新成功后,检查是否已设置 privacy_mode,未设置则尝试关闭训练数据共享
|
||||
s.ensureOpenAIPrivacy(ctx, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -341,3 +353,49 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ensureOpenAIPrivacy 检查 OpenAI OAuth 账号是否已设置 privacy_mode,
|
||||
// 未设置则调用 disableOpenAITraining 并持久化结果到 Extra。
|
||||
func (s *TokenRefreshService) ensureOpenAIPrivacy(ctx context.Context, account *Account) {
|
||||
if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
|
||||
return
|
||||
}
|
||||
if s.privacyClientFactory == nil {
|
||||
return
|
||||
}
|
||||
// 已设置过则跳过
|
||||
if account.Extra != nil {
|
||||
if _, ok := account.Extra["privacy_mode"]; ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
token, _ := account.Credentials["access_token"].(string)
|
||||
if token == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && s.proxyRepo != nil {
|
||||
if p, err := s.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && p != nil {
|
||||
proxyURL = p.URL()
|
||||
}
|
||||
}
|
||||
|
||||
mode := disableOpenAITraining(ctx, s.privacyClientFactory, token, proxyURL)
|
||||
if mode == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{"privacy_mode": mode}); err != nil {
|
||||
slog.Warn("token_refresh.update_privacy_mode_failed",
|
||||
"account_id", account.ID,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.privacy_mode_set",
|
||||
"account_id", account.ID,
|
||||
"privacy_mode", mode,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,22 @@ package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserGroupRateEntry 分组下用户专属倍率条目
|
||||
type UserGroupRateEntry struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserEmail string `json:"user_email"`
|
||||
UserNotes string `json:"user_notes"`
|
||||
UserStatus string `json:"user_status"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
}
|
||||
|
||||
// GroupRateMultiplierInput 批量设置分组倍率的输入条目
|
||||
type GroupRateMultiplierInput struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
}
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
type UserGroupRateRepository interface {
|
||||
@@ -13,10 +29,16 @@ type UserGroupRateRepository interface {
|
||||
// 如果未设置专属倍率,返回 nil
|
||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||
|
||||
// GetByGroupID 获取指定分组下所有用户的专属倍率
|
||||
GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||
|
||||
// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据)
|
||||
SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
|
||||
DeleteByGroupID(ctx context.Context, groupID int64) error
|
||||
|
||||
|
||||
@@ -49,10 +49,14 @@ func ProvideTokenRefreshService(
|
||||
schedulerCache SchedulerCache,
|
||||
cfg *config.Config,
|
||||
tempUnschedCache TempUnschedCache,
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
proxyRepo ProxyRepository,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 Sora 账号扩展表仓储,用于 OpenAI Token 刷新时同步 sora_accounts 表
|
||||
svc.SetSoraAccountRepo(soraAccountRepo)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
svc.SetPrivacyDeps(privacyClientFactory, proxyRepo)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user