diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 24cd93a2..034c70ec 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -104,7 +104,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 236bd658..0b36bf66 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -645,7 +645,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 27e2173a..dec4ed33 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -432,6 +432,7 @@ type adminServiceImpl struct { entClient *dbent.Client // 用于开启数据库事务 settingService *SettingService defaultSubAssigner DefaultSubscriptionAssigner + userSubRepo UserSubscriptionRepository } type userGroupRateBatchReader interface { @@ -459,6 +460,7 @@ func NewAdminService( entClient *dbent.Client, settingService *SettingService, defaultSubAssigner DefaultSubscriptionAssigner, + userSubRepo UserSubscriptionRepository, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -476,6 +478,7 @@ func NewAdminService( entClient: entClient, settingService: settingService, defaultSubAssigner: defaultSubAssigner, + userSubRepo: userSubRepo, } } @@ -1277,9 +1280,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i if group.Status != StatusActive { return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") } - // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 + // 订阅类型分组:用户须持有该分组的有效订阅才可绑定 if group.IsSubscriptionType() { - return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") + if s.userSubRepo == nil { + return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") + } + if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { + if errors.Is(err, ErrSubscriptionNotFound) { + return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") + } + return nil, err + } } gid := *groupID @@ -1287,7 +1298,7 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i apiKey.Group = group // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 - if group.IsExclusive { + if group.IsExclusive && !group.IsSubscriptionType() { opCtx := ctx var tx *dbent.Tx if s.entClient == nil { diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index a6d12f97..88d2f492 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -32,28 +32,44 @@ func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, return s.addGroupErr } -func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. type apiKeyRepoStubForGroupUpdate struct { @@ -194,6 +210,29 @@ func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupS panic("unexpected") } +type userSubRepoStubForGroupUpdate struct { + userSubRepoNoop + getActiveSub *UserSubscription + getActiveErr error + called bool + calledUserID int64 + calledGroupID int64 +} + +func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) { + s.called = true + s.calledUserID = userID + s.calledGroupID = groupID + if s.getActiveErr != nil { + return nil, s.getActiveErr + } + if s.getActiveSub == nil { + return nil, ErrSubscriptionNotFound + } + clone := *s.getActiveSub + return &clone, nil +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -386,14 +425,49 @@ func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupU func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} - groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} + + // 无有效订阅时应拒绝绑定 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err)) + require.True(t, userSubRepo.called) + require.Equal(t, int64(42), userSubRepo.calledUserID) + require.Equal(t, int64(10), userSubRepo.calledGroupID) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}} userRepo := &userRepoStubForGroupUpdate{} svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} - // 订阅类型分组应被阻止绑定 _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) require.Error(t, err) - require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err)) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := &userRepoStubForGroupUpdate{} + userSubRepo := &userSubRepoStubForGroupUpdate{ + getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10}, + } + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.True(t, userSubRepo.called) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) require.False(t, userRepo.addGroupCalled) } diff --git a/frontend/src/components/admin/user/UserApiKeysModal.vue b/frontend/src/components/admin/user/UserApiKeysModal.vue index 7e3c8c25..5e0a0fea 100644 --- a/frontend/src/components/admin/user/UserApiKeysModal.vue +++ b/frontend/src/components/admin/user/UserApiKeysModal.vue @@ -162,8 +162,7 @@ const load = async () => { const loadGroups = async () => { try { const groups = await adminAPI.groups.getAll() - // 过滤掉订阅类型分组(需通过订阅管理流程绑定) - allGroups.value = groups.filter((g) => g.subscription_type !== 'subscription') + allGroups.value = groups } catch (error) { console.error('Failed to load groups:', error) }