diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go new file mode 100644 index 000000000..a4acfc38a --- /dev/null +++ b/controller/custom_oauth.go @@ -0,0 +1,386 @@ +package controller + +import ( + "net/http" + "strconv" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/oauth" + "github.com/gin-gonic/gin" +) + +// CustomOAuthProviderResponse is the response structure for custom OAuth providers +// It excludes sensitive fields like client_secret +type CustomOAuthProviderResponse struct { + Id int `json:"id"` + Name string `json:"name"` + Slug string `json:"slug"` + Enabled bool `json:"enabled"` + ClientId string `json:"client_id"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"user_info_endpoint"` + Scopes string `json:"scopes"` + UserIdField string `json:"user_id_field"` + UsernameField string `json:"username_field"` + DisplayNameField string `json:"display_name_field"` + EmailField string `json:"email_field"` + WellKnown string `json:"well_known"` + AuthStyle int `json:"auth_style"` +} + +func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse { + return &CustomOAuthProviderResponse{ + Id: p.Id, + Name: p.Name, + Slug: p.Slug, + Enabled: p.Enabled, + ClientId: p.ClientId, + AuthorizationEndpoint: p.AuthorizationEndpoint, + TokenEndpoint: p.TokenEndpoint, + UserInfoEndpoint: p.UserInfoEndpoint, + Scopes: p.Scopes, + UserIdField: p.UserIdField, + UsernameField: p.UsernameField, + DisplayNameField: p.DisplayNameField, + EmailField: p.EmailField, + WellKnown: p.WellKnown, + AuthStyle: p.AuthStyle, + } +} + +// GetCustomOAuthProviders returns all custom OAuth providers +func GetCustomOAuthProviders(c *gin.Context) { + providers, err := model.GetAllCustomOAuthProviders() + if err != nil { + common.ApiError(c, err) + return + } + + response := make([]*CustomOAuthProviderResponse, len(providers)) + for i, p := range providers { + response[i] = toCustomOAuthProviderResponse(p) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": response, + }) +} + +// GetCustomOAuthProvider returns a single custom OAuth provider by ID +func GetCustomOAuthProvider(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiErrorMsg(c, "无效的 ID") + return + } + + provider, err := model.GetCustomOAuthProviderById(id) + if err != nil { + common.ApiErrorMsg(c, "未找到该 OAuth 提供商") + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": toCustomOAuthProviderResponse(provider), + }) +} + +// CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider +type CreateCustomOAuthProviderRequest struct { + Name string `json:"name" binding:"required"` + Slug string `json:"slug" binding:"required"` + Enabled bool `json:"enabled"` + ClientId string `json:"client_id" binding:"required"` + ClientSecret string `json:"client_secret" binding:"required"` + AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"` + TokenEndpoint string `json:"token_endpoint" binding:"required"` + UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"` + Scopes string `json:"scopes"` + UserIdField string `json:"user_id_field"` + UsernameField string `json:"username_field"` + DisplayNameField string `json:"display_name_field"` + EmailField string `json:"email_field"` + WellKnown string `json:"well_known"` + AuthStyle int `json:"auth_style"` +} + +// CreateCustomOAuthProvider creates a new custom OAuth provider +func CreateCustomOAuthProvider(c *gin.Context) { + var req CreateCustomOAuthProviderRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) + return + } + + // Check if slug is already taken + if model.IsSlugTaken(req.Slug, 0) { + common.ApiErrorMsg(c, "该 Slug 已被使用") + return + } + + // Check if slug conflicts with built-in providers + if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) { + common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突") + return + } + + provider := &model.CustomOAuthProvider{ + Name: req.Name, + Slug: req.Slug, + Enabled: req.Enabled, + ClientId: req.ClientId, + ClientSecret: req.ClientSecret, + AuthorizationEndpoint: req.AuthorizationEndpoint, + TokenEndpoint: req.TokenEndpoint, + UserInfoEndpoint: req.UserInfoEndpoint, + Scopes: req.Scopes, + UserIdField: req.UserIdField, + UsernameField: req.UsernameField, + DisplayNameField: req.DisplayNameField, + EmailField: req.EmailField, + WellKnown: req.WellKnown, + AuthStyle: req.AuthStyle, + } + + if err := model.CreateCustomOAuthProvider(provider); err != nil { + common.ApiError(c, err) + return + } + + // Register the provider in the OAuth registry + oauth.RegisterOrUpdateCustomProvider(provider) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "创建成功", + "data": toCustomOAuthProviderResponse(provider), + }) +} + +// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider +type UpdateCustomOAuthProviderRequest struct { + Name string `json:"name"` + Slug string `json:"slug"` + Enabled bool `json:"enabled"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"user_info_endpoint"` + Scopes string `json:"scopes"` + UserIdField string `json:"user_id_field"` + UsernameField string `json:"username_field"` + DisplayNameField string `json:"display_name_field"` + EmailField string `json:"email_field"` + WellKnown string `json:"well_known"` + AuthStyle int `json:"auth_style"` +} + +// UpdateCustomOAuthProvider updates an existing custom OAuth provider +func UpdateCustomOAuthProvider(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiErrorMsg(c, "无效的 ID") + return + } + + var req UpdateCustomOAuthProviderRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) + return + } + + // Get existing provider + provider, err := model.GetCustomOAuthProviderById(id) + if err != nil { + common.ApiErrorMsg(c, "未找到该 OAuth 提供商") + return + } + + oldSlug := provider.Slug + + // Check if new slug is taken by another provider + if req.Slug != "" && req.Slug != provider.Slug { + if model.IsSlugTaken(req.Slug, id) { + common.ApiErrorMsg(c, "该 Slug 已被使用") + return + } + // Check if slug conflicts with built-in providers + if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) { + common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突") + return + } + } + + // Update fields + if req.Name != "" { + provider.Name = req.Name + } + if req.Slug != "" { + provider.Slug = req.Slug + } + provider.Enabled = req.Enabled + if req.ClientId != "" { + provider.ClientId = req.ClientId + } + if req.ClientSecret != "" { + provider.ClientSecret = req.ClientSecret + } + if req.AuthorizationEndpoint != "" { + provider.AuthorizationEndpoint = req.AuthorizationEndpoint + } + if req.TokenEndpoint != "" { + provider.TokenEndpoint = req.TokenEndpoint + } + if req.UserInfoEndpoint != "" { + provider.UserInfoEndpoint = req.UserInfoEndpoint + } + if req.Scopes != "" { + provider.Scopes = req.Scopes + } + if req.UserIdField != "" { + provider.UserIdField = req.UserIdField + } + if req.UsernameField != "" { + provider.UsernameField = req.UsernameField + } + if req.DisplayNameField != "" { + provider.DisplayNameField = req.DisplayNameField + } + if req.EmailField != "" { + provider.EmailField = req.EmailField + } + provider.WellKnown = req.WellKnown + provider.AuthStyle = req.AuthStyle + + if err := model.UpdateCustomOAuthProvider(provider); err != nil { + common.ApiError(c, err) + return + } + + // Update the provider in the OAuth registry + if oldSlug != provider.Slug { + oauth.UnregisterCustomProvider(oldSlug) + } + oauth.RegisterOrUpdateCustomProvider(provider) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "更新成功", + "data": toCustomOAuthProviderResponse(provider), + }) +} + +// DeleteCustomOAuthProvider deletes a custom OAuth provider +func DeleteCustomOAuthProvider(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ApiErrorMsg(c, "无效的 ID") + return + } + + // Get existing provider to get slug + provider, err := model.GetCustomOAuthProviderById(id) + if err != nil { + common.ApiErrorMsg(c, "未找到该 OAuth 提供商") + return + } + + // Check if there are any user bindings + count, _ := model.GetBindingCountByProviderId(id) + if count > 0 { + common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。") + return + } + + if err := model.DeleteCustomOAuthProvider(id); err != nil { + common.ApiError(c, err) + return + } + + // Unregister the provider from the OAuth registry + oauth.UnregisterCustomProvider(provider.Slug) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "删除成功", + }) +} + +// GetUserOAuthBindings returns all OAuth bindings for the current user +func GetUserOAuthBindings(c *gin.Context) { + userId := c.GetInt("id") + if userId == 0 { + common.ApiErrorMsg(c, "未登录") + return + } + + bindings, err := model.GetUserOAuthBindingsByUserId(userId) + if err != nil { + common.ApiError(c, err) + return + } + + // Build response with provider info + type BindingResponse struct { + ProviderId int `json:"provider_id"` + ProviderName string `json:"provider_name"` + ProviderSlug string `json:"provider_slug"` + ProviderUserId string `json:"provider_user_id"` + } + + response := make([]BindingResponse, 0) + for _, binding := range bindings { + provider, err := model.GetCustomOAuthProviderById(binding.ProviderId) + if err != nil { + continue // Skip if provider not found + } + response = append(response, BindingResponse{ + ProviderId: binding.ProviderId, + ProviderName: provider.Name, + ProviderSlug: provider.Slug, + ProviderUserId: binding.ProviderUserId, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": response, + }) +} + +// UnbindCustomOAuth unbinds a custom OAuth provider from the current user +func UnbindCustomOAuth(c *gin.Context) { + userId := c.GetInt("id") + if userId == 0 { + common.ApiErrorMsg(c, "未登录") + return + } + + providerIdStr := c.Param("provider_id") + providerId, err := strconv.Atoi(providerIdStr) + if err != nil { + common.ApiErrorMsg(c, "无效的提供商 ID") + return + } + + if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil { + common.ApiError(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "解绑成功", + }) +} diff --git a/controller/misc.go b/controller/misc.go index e76ca51bb..a16e2d554 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -10,6 +10,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/oauth" "github.com/QuantumNous/new-api/setting" "github.com/QuantumNous/new-api/setting/console_setting" "github.com/QuantumNous/new-api/setting/operation_setting" @@ -129,6 +130,30 @@ func GetStatus(c *gin.Context) { data["faq"] = console_setting.GetFAQ() } + // Add enabled custom OAuth providers + customProviders := oauth.GetEnabledCustomProviders() + if len(customProviders) > 0 { + type CustomOAuthInfo struct { + Name string `json:"name"` + Slug string `json:"slug"` + ClientId string `json:"client_id"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + Scopes string `json:"scopes"` + } + providersInfo := make([]CustomOAuthInfo, 0, len(customProviders)) + for _, p := range customProviders { + config := p.GetConfig() + providersInfo = append(providersInfo, CustomOAuthInfo{ + Name: config.Name, + Slug: config.Slug, + ClientId: config.ClientId, + AuthorizationEndpoint: config.AuthorizationEndpoint, + Scopes: config.Scopes, + }) + } + data["custom_oauth_providers"] = providersInfo + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/oauth.go b/controller/oauth.go index fb9e59a59..58cb40d5c 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -171,12 +171,22 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) { return } - // Update user with OAuth ID - provider.SetProviderUserID(&user, oauthUser.ProviderUserID) - err = user.Update(false) - if err != nil { - common.ApiError(c, err) - return + // Handle binding based on provider type + if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok { + // Custom provider: use user_oauth_bindings table + err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID) + if err != nil { + common.ApiError(c, err) + return + } + } else { + // Built-in provider: update user record directly + provider.SetProviderUserID(&user, oauthUser.ProviderUserID) + err = user.Update(false) + if err != nil { + common.ApiError(c, err) + return + } } common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil) @@ -188,7 +198,6 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Check if user already exists with new ID if provider.IsUserIDTaken(oauthUser.ProviderUserID) { - provider.SetProviderUserID(user, oauthUser.ProviderUserID) err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID) if err != nil { return nil, err @@ -203,7 +212,6 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Try to find user with legacy ID (for GitHub migration from login to numeric ID) if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" { if provider.IsUserIDTaken(legacyID) { - provider.SetProviderUserID(user, legacyID) err := provider.FillUserByProviderID(user, legacyID) if err != nil { return nil, err @@ -240,7 +248,6 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o } user.Role = common.RoleCommonUser user.Status = common.UserStatusEnabled - provider.SetProviderUserID(user, oauthUser.ProviderUserID) // Handle affiliate code affCode := session.Get("aff") @@ -253,6 +260,25 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o return nil, err } + // For custom providers, create the binding after user is created + if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok { + binding := &model.UserOAuthBinding{ + UserId: user.Id, + ProviderId: genericProvider.GetProviderId(), + ProviderUserId: oauthUser.ProviderUserID, + } + if err := model.CreateUserOAuthBinding(binding); err != nil { + common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error())) + // Don't fail the registration, just log the error + } + } else { + // Built-in provider: set the provider user ID on the user model + provider.SetProviderUserID(user, oauthUser.ProviderUserID) + if err := user.Update(false); err != nil { + common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error())) + } + } + return user, nil } diff --git a/i18n/keys.go b/i18n/keys.go index 6ac0a574c..d1fd00c61 100644 --- a/i18n/keys.go +++ b/i18n/keys.go @@ -287,3 +287,14 @@ const ( MsgUuidDuplicate = "common.uuid_duplicate" MsgInvalidInput = "common.invalid_input" ) + +// Custom OAuth provider related messages +const ( + MsgCustomOAuthNotFound = "custom_oauth.not_found" + MsgCustomOAuthSlugEmpty = "custom_oauth.slug_empty" + MsgCustomOAuthSlugExists = "custom_oauth.slug_exists" + MsgCustomOAuthNameEmpty = "custom_oauth.name_empty" + MsgCustomOAuthHasBindings = "custom_oauth.has_bindings" + MsgCustomOAuthBindingNotFound = "custom_oauth.binding_not_found" + MsgCustomOAuthProviderIdInvalid = "custom_oauth.provider_id_field_invalid" +) diff --git a/i18n/locales/en.yaml b/i18n/locales/en.yaml index e44f7ad7b..be5df367e 100644 --- a/i18n/locales/en.yaml +++ b/i18n/locales/en.yaml @@ -240,3 +240,12 @@ redeem.failed: "Redemption failed, please try again later" user.create_default_token_error: "Failed to create default token" common.uuid_duplicate: "Please retry, the system generated a duplicate UUID!" common.invalid_input: "Invalid input" + +# Custom OAuth provider messages +custom_oauth.not_found: "Custom OAuth provider not found" +custom_oauth.slug_empty: "Slug cannot be empty" +custom_oauth.slug_exists: "Slug already exists" +custom_oauth.name_empty: "Provider name cannot be empty" +custom_oauth.has_bindings: "Cannot delete provider with existing user bindings" +custom_oauth.binding_not_found: "OAuth binding not found" +custom_oauth.provider_id_field_invalid: "Could not extract user ID from provider response" diff --git a/i18n/locales/zh.yaml b/i18n/locales/zh.yaml index 9098e977e..0f4460c6a 100644 --- a/i18n/locales/zh.yaml +++ b/i18n/locales/zh.yaml @@ -241,3 +241,12 @@ redeem.failed: "兑换失败,请稍后重试" user.create_default_token_error: "创建默认令牌失败" common.uuid_duplicate: "请重试,系统生成的 UUID 竟然重复了!" common.invalid_input: "输入不合法" + +# Custom OAuth provider messages +custom_oauth.not_found: "自定义 OAuth 提供商不存在" +custom_oauth.slug_empty: "标识符不能为空" +custom_oauth.slug_exists: "标识符已存在" +custom_oauth.name_empty: "提供商名称不能为空" +custom_oauth.has_bindings: "无法删除已有用户绑定的提供商" +custom_oauth.binding_not_found: "OAuth 绑定不存在" +custom_oauth.provider_id_field_invalid: "无法从提供商响应中提取用户 ID" diff --git a/main.go b/main.go index 4f9cf84ee..852e1a0a8 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/middleware" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/oauth" "github.com/QuantumNous/new-api/router" "github.com/QuantumNous/new-api/service" _ "github.com/QuantumNous/new-api/setting/performance_setting" @@ -291,5 +292,12 @@ func InitResources() error { // Register user language loader for lazy loading i18n.SetUserLangLoader(model.GetUserLanguage) + // Load custom OAuth providers from database + err = oauth.LoadCustomProviders() + if err != nil { + common.SysError("failed to load custom OAuth providers: " + err.Error()) + // Don't return error, custom OAuth is not critical + } + return nil } diff --git a/model/custom_oauth_provider.go b/model/custom_oauth_provider.go new file mode 100644 index 000000000..884e87b06 --- /dev/null +++ b/model/custom_oauth_provider.go @@ -0,0 +1,158 @@ +package model + +import ( + "errors" + "strings" + "time" +) + +// CustomOAuthProvider stores configuration for custom OAuth providers +type CustomOAuthProvider struct { + Id int `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" + Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" + Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled + ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID + ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) + AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL + TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL + UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL + Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes + + // Field mapping configuration (supports JSONPath via gjson) + UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" + UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path + DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path + EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path + + // Advanced options + WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) + AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (CustomOAuthProvider) TableName() string { + return "custom_oauth_providers" +} + +// GetAllCustomOAuthProviders returns all custom OAuth providers +func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) { + var providers []*CustomOAuthProvider + err := DB.Order("id asc").Find(&providers).Error + return providers, err +} + +// GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers +func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) { + var providers []*CustomOAuthProvider + err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error + return providers, err +} + +// GetCustomOAuthProviderById returns a custom OAuth provider by ID +func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) { + var provider CustomOAuthProvider + err := DB.First(&provider, id).Error + if err != nil { + return nil, err + } + return &provider, nil +} + +// GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug +func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) { + var provider CustomOAuthProvider + err := DB.Where("slug = ?", slug).First(&provider).Error + if err != nil { + return nil, err + } + return &provider, nil +} + +// CreateCustomOAuthProvider creates a new custom OAuth provider +func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error { + if err := validateCustomOAuthProvider(provider); err != nil { + return err + } + return DB.Create(provider).Error +} + +// UpdateCustomOAuthProvider updates an existing custom OAuth provider +func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error { + if err := validateCustomOAuthProvider(provider); err != nil { + return err + } + return DB.Save(provider).Error +} + +// DeleteCustomOAuthProvider deletes a custom OAuth provider by ID +func DeleteCustomOAuthProvider(id int) error { + // First, delete all user bindings for this provider + if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil { + return err + } + return DB.Delete(&CustomOAuthProvider{}, id).Error +} + +// IsSlugTaken checks if a slug is already taken by another provider +func IsSlugTaken(slug string, excludeId int) bool { + var count int64 + query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug) + if excludeId > 0 { + query = query.Where("id != ?", excludeId) + } + query.Count(&count) + return count > 0 +} + +// validateCustomOAuthProvider validates a custom OAuth provider configuration +func validateCustomOAuthProvider(provider *CustomOAuthProvider) error { + if provider.Name == "" { + return errors.New("provider name is required") + } + if provider.Slug == "" { + return errors.New("provider slug is required") + } + // Slug must be lowercase and contain only alphanumeric characters and hyphens + slug := strings.ToLower(provider.Slug) + for _, c := range slug { + if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { + return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens") + } + } + provider.Slug = slug + + if provider.ClientId == "" { + return errors.New("client ID is required") + } + if provider.AuthorizationEndpoint == "" { + return errors.New("authorization endpoint is required") + } + if provider.TokenEndpoint == "" { + return errors.New("token endpoint is required") + } + if provider.UserInfoEndpoint == "" { + return errors.New("user info endpoint is required") + } + + // Set defaults for field mappings if empty + if provider.UserIdField == "" { + provider.UserIdField = "sub" + } + if provider.UsernameField == "" { + provider.UsernameField = "preferred_username" + } + if provider.DisplayNameField == "" { + provider.DisplayNameField = "name" + } + if provider.EmailField == "" { + provider.EmailField = "email" + } + if provider.Scopes == "" { + provider.Scopes = "openid profile email" + } + + return nil +} diff --git a/model/main.go b/model/main.go index e78970950..21a5d4c08 100644 --- a/model/main.go +++ b/model/main.go @@ -274,6 +274,8 @@ func migrateDB() error { &SubscriptionOrder{}, &UserSubscription{}, &SubscriptionPreConsumeRecord{}, + &CustomOAuthProvider{}, + &UserOAuthBinding{}, ) if err != nil { return err @@ -320,6 +322,8 @@ func migrateDBFast() error { {&SubscriptionOrder{}, "SubscriptionOrder"}, {&UserSubscription{}, "UserSubscription"}, {&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"}, + {&CustomOAuthProvider{}, "CustomOAuthProvider"}, + {&UserOAuthBinding{}, "UserOAuthBinding"}, } // 动态计算migration数量,确保errChan缓冲区足够大 errChan := make(chan error, len(migrations)) diff --git a/model/user_oauth_binding.go b/model/user_oauth_binding.go new file mode 100644 index 000000000..7b2acd474 --- /dev/null +++ b/model/user_oauth_binding.go @@ -0,0 +1,125 @@ +package model + +import ( + "errors" + "time" +) + +// UserOAuthBinding stores the binding relationship between users and custom OAuth providers +type UserOAuthBinding struct { + Id int `json:"id" gorm:"primaryKey"` + UserId int `json:"user_id" gorm:"index;not null"` // User ID + ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID + ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider + CreatedAt time.Time `json:"created_at"` + + // Composite unique index to prevent duplicate bindings + // One OAuth account can only be bound to one user +} + +func (UserOAuthBinding) TableName() string { + return "user_oauth_bindings" +} + +// GetUserOAuthBindingsByUserId returns all OAuth bindings for a user +func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) { + var bindings []*UserOAuthBinding + err := DB.Where("user_id = ?", userId).Find(&bindings).Error + return bindings, err +} + +// GetUserOAuthBinding returns a specific binding for a user and provider +func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) { + var binding UserOAuthBinding + err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error + if err != nil { + return nil, err + } + return &binding, nil +} + +// GetUserByOAuthBinding finds a user by provider ID and provider user ID +func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) { + var binding UserOAuthBinding + err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error + if err != nil { + return nil, err + } + + var user User + err = DB.First(&user, binding.UserId).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// IsProviderUserIdTaken checks if a provider user ID is already bound to any user +func IsProviderUserIdTaken(providerId int, providerUserId string) bool { + var count int64 + DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count) + return count > 0 +} + +// CreateUserOAuthBinding creates a new OAuth binding +func CreateUserOAuthBinding(binding *UserOAuthBinding) error { + if binding.UserId == 0 { + return errors.New("user ID is required") + } + if binding.ProviderId == 0 { + return errors.New("provider ID is required") + } + if binding.ProviderUserId == "" { + return errors.New("provider user ID is required") + } + + // Check if this provider user ID is already taken + if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) { + return errors.New("this OAuth account is already bound to another user") + } + + binding.CreatedAt = time.Now() + return DB.Create(binding).Error +} + +// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account) +func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error { + // Check if the new provider user ID is already taken by another user + var existingBinding UserOAuthBinding + err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error + if err == nil && existingBinding.UserId != userId { + return errors.New("this OAuth account is already bound to another user") + } + + // Check if user already has a binding for this provider + var binding UserOAuthBinding + err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error + if err != nil { + // No existing binding, create new one + return CreateUserOAuthBinding(&UserOAuthBinding{ + UserId: userId, + ProviderId: providerId, + ProviderUserId: newProviderUserId, + }) + } + + // Update existing binding + return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error +} + +// DeleteUserOAuthBinding deletes an OAuth binding +func DeleteUserOAuthBinding(userId, providerId int) error { + return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error +} + +// DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user +func DeleteUserOAuthBindingsByUserId(userId int) error { + return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error +} + +// GetBindingCountByProviderId returns the number of bindings for a provider +func GetBindingCountByProviderId(providerId int) (int64, error) { + var count int64 + err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error + return count, err +} diff --git a/oauth/generic.go b/oauth/generic.go new file mode 100644 index 000000000..c7aa87931 --- /dev/null +++ b/oauth/generic.go @@ -0,0 +1,268 @@ +package oauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/QuantumNous/new-api/i18n" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +// AuthStyle defines how to send client credentials +const ( + AuthStyleAutoDetect = 0 // Auto-detect based on server response + AuthStyleInParams = 1 // Send client_id and client_secret as POST parameters + AuthStyleInHeader = 2 // Send as Basic Auth header +) + +// GenericOAuthProvider implements OAuth for custom/generic OAuth providers +type GenericOAuthProvider struct { + config *model.CustomOAuthProvider +} + +// NewGenericOAuthProvider creates a new generic OAuth provider from config +func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider { + return &GenericOAuthProvider{config: config} +} + +func (p *GenericOAuthProvider) GetName() string { + return p.config.Name +} + +func (p *GenericOAuthProvider) IsEnabled() bool { + return p.config.Enabled +} + +func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider { + return p.config +} + +func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) { + if code == "" { + return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil) + } + + logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)]) + + redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug) + values := url.Values{} + values.Set("grant_type", "authorization_code") + values.Set("code", code) + values.Set("redirect_uri", redirectUri) + + // Determine auth style + authStyle := p.config.AuthStyle + if authStyle == AuthStyleAutoDetect { + // Default to params style for most OAuth servers + authStyle = AuthStyleInParams + } + + var req *http.Request + var err error + + if authStyle == AuthStyleInParams { + values.Set("client_id", p.config.ClientId) + values.Set("client_secret", p.config.ClientSecret) + } + + req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + if authStyle == AuthStyleInHeader { + // Basic Auth + credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret)) + req.Header.Set("Authorization", "Basic "+credentials) + } + + logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d", + p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle) + + client := http.Client{ + Timeout: 20 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode) + + body, err := io.ReadAll(res.Body) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error())) + return nil, err + } + + bodyStr := string(body) + logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)]) + + // Try to parse as JSON first + var tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + IDToken string `json:"id_token"` + Error string `json:"error"` + ErrorDesc string `json:"error_description"` + } + + if err := json.Unmarshal(body, &tokenResponse); err != nil { + // Try to parse as URL-encoded (some OAuth servers like GitHub return this format) + parsedValues, parseErr := url.ParseQuery(bodyStr) + if parseErr != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error())) + return nil, err + } + tokenResponse.AccessToken = parsedValues.Get("access_token") + tokenResponse.TokenType = parsedValues.Get("token_type") + tokenResponse.Scope = parsedValues.Get("scope") + } + + if tokenResponse.Error != "" { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s", + p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc)) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc) + } + + if tokenResponse.AccessToken == "" { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug)) + return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}) + } + + logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope) + + return &OAuthToken{ + AccessToken: tokenResponse.AccessToken, + TokenType: tokenResponse.TokenType, + RefreshToken: tokenResponse.RefreshToken, + ExpiresIn: tokenResponse.ExpiresIn, + Scope: tokenResponse.Scope, + IDToken: tokenResponse.IDToken, + }, nil +} + +func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) { + logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint) + + req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil) + if err != nil { + return nil, err + } + + // Set authorization header + tokenType := token.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken)) + req.Header.Set("Accept", "application/json") + + client := http.Client{ + Timeout: 20 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode) + + if res.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode)) + return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error())) + return nil, err + } + + bodyStr := string(body) + logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)]) + + // Extract fields using gjson (supports JSONPath-like syntax) + userId := gjson.Get(bodyStr, p.config.UserIdField).String() + username := gjson.Get(bodyStr, p.config.UsernameField).String() + displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String() + email := gjson.Get(bodyStr, p.config.EmailField).String() + + // If user ID field returns a number, convert it + if userId == "" { + // Try to get as number + userIdNum := gjson.Get(bodyStr, p.config.UserIdField) + if userIdNum.Exists() { + userId = userIdNum.Raw + // Remove quotes if present + userId = strings.Trim(userId, "\"") + } + } + + if userId == "" { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField)) + return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name}) + } + + logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s", + p.config.Slug, userId, username, displayName, email) + + return &OAuthUser{ + ProviderUserID: userId, + Username: username, + DisplayName: displayName, + Email: email, + }, nil +} + +func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool { + return model.IsProviderUserIdTaken(p.config.Id, providerUserID) +} + +func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error { + foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID) + if err != nil { + return err + } + *user = *foundUser + return nil +} + +func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) { + // For generic providers, we store the binding in user_oauth_bindings table + // This is handled separately in the OAuth controller +} + +func (p *GenericOAuthProvider) GetProviderPrefix() string { + return p.config.Slug + "_" +} + +// GetProviderId returns the provider ID for binding purposes +func (p *GenericOAuthProvider) GetProviderId() int { + return p.config.Id +} + +// IsGenericProvider returns true for generic providers +func (p *GenericOAuthProvider) IsGenericProvider() bool { + return true +} diff --git a/oauth/registry.go b/oauth/registry.go index 13ee2bcfb..91d196364 100644 --- a/oauth/registry.go +++ b/oauth/registry.go @@ -1,12 +1,18 @@ package oauth import ( + "fmt" "sync" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" ) var ( providers = make(map[string]Provider) mu sync.RWMutex + // customProviderSlugs tracks which providers are custom (can be unregistered) + customProviderSlugs = make(map[string]bool) ) // Register registers an OAuth provider with the given name @@ -16,6 +22,22 @@ func Register(name string, provider Provider) { providers[name] = provider } +// RegisterCustom registers a custom OAuth provider (can be unregistered later) +func RegisterCustom(name string, provider Provider) { + mu.Lock() + defer mu.Unlock() + providers[name] = provider + customProviderSlugs[name] = true +} + +// Unregister removes a provider from the registry +func Unregister(name string) { + mu.Lock() + defer mu.Unlock() + delete(providers, name) + delete(customProviderSlugs, name) +} + // GetProvider returns the OAuth provider for the given name func GetProvider(name string) Provider { mu.RLock() @@ -34,6 +56,21 @@ func GetAllProviders() map[string]Provider { return result } +// GetEnabledCustomProviders returns all enabled custom OAuth providers +func GetEnabledCustomProviders() []*GenericOAuthProvider { + mu.RLock() + defer mu.RUnlock() + var result []*GenericOAuthProvider + for name, provider := range providers { + if customProviderSlugs[name] { + if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() { + result = append(result, gp) + } + } + } + return result +} + // IsProviderRegistered checks if a provider is registered func IsProviderRegistered(name string) bool { mu.RLock() @@ -41,3 +78,57 @@ func IsProviderRegistered(name string) bool { _, ok := providers[name] return ok } + +// IsCustomProvider checks if a provider is a custom provider +func IsCustomProvider(name string) bool { + mu.RLock() + defer mu.RUnlock() + return customProviderSlugs[name] +} + +// LoadCustomProviders loads all custom OAuth providers from the database +func LoadCustomProviders() error { + // First, unregister all existing custom providers + mu.Lock() + for name := range customProviderSlugs { + delete(providers, name) + } + customProviderSlugs = make(map[string]bool) + mu.Unlock() + + // Load all custom providers from database + customProviders, err := model.GetAllCustomOAuthProviders() + if err != nil { + common.SysError("Failed to load custom OAuth providers: " + err.Error()) + return err + } + + // Register each custom provider + for _, config := range customProviders { + provider := NewGenericOAuthProvider(config) + RegisterCustom(config.Slug, provider) + common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")") + } + + common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders))) + return nil +} + +// ReloadCustomProviders reloads all custom OAuth providers from the database +func ReloadCustomProviders() error { + return LoadCustomProviders() +} + +// RegisterOrUpdateCustomProvider registers or updates a single custom provider +func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) { + provider := NewGenericOAuthProvider(config) + mu.Lock() + defer mu.Unlock() + providers[config.Slug] = provider + customProviderSlugs[config.Slug] = true +} + +// UnregisterCustomProvider unregisters a custom provider by slug +func UnregisterCustomProvider(slug string) { + Unregister(slug) +} diff --git a/router/api-router.go b/router/api-router.go index 2b84295a1..e26f9b700 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -102,6 +102,10 @@ func SetApiRouter(router *gin.Engine) { // Check-in routes selfRoute.GET("/checkin", controller.GetCheckinStatus) selfRoute.POST("/checkin", middleware.TurnstileCheck(), controller.DoCheckin) + + // Custom OAuth bindings + selfRoute.GET("/oauth/bindings", controller.GetUserOAuthBindings) + selfRoute.DELETE("/oauth/bindings/:provider_id", controller.UnbindCustomOAuth) } adminRoute := userRoute.Group("/") @@ -166,6 +170,17 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } + + // Custom OAuth provider management (admin only) + customOAuthRoute := apiRouter.Group("/custom-oauth-provider") + customOAuthRoute.Use(middleware.RootAuth()) + { + customOAuthRoute.GET("/", controller.GetCustomOAuthProviders) + customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider) + customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider) + customOAuthRoute.PUT("/:id", controller.UpdateCustomOAuthProvider) + customOAuthRoute.DELETE("/:id", controller.DeleteCustomOAuthProvider) + } performanceRoute := apiRouter.Group("/performance") performanceRoute.Use(middleware.RootAuth()) { diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx index 134451ec3..636317e44 100644 --- a/web/src/components/auth/LoginForm.jsx +++ b/web/src/components/auth/LoginForm.jsx @@ -34,6 +34,7 @@ import { onDiscordOAuthClicked, onOIDCClicked, onLinuxDOOAuthClicked, + onCustomOAuthClicked, prepareCredentialRequestOptions, buildAssertionResult, isPasskeySupported, @@ -109,6 +110,7 @@ const LoginForm = () => { const [githubButtonDisabled, setGithubButtonDisabled] = useState(false); const githubTimeoutRef = useRef(null); const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]); + const [customOAuthLoading, setCustomOAuthLoading] = useState({}); const logo = getLogo(); const systemName = getSystemName(); @@ -373,6 +375,23 @@ const LoginForm = () => { } }; + // 包装的自定义OAuth登录点击处理 + const handleCustomOAuthClick = (provider) => { + if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) { + showInfo(t('请先阅读并同意用户协议和隐私政策')); + return; + } + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true })); + try { + onCustomOAuthClicked(provider, { shouldLogout: true }); + } finally { + // 由于重定向,这里不会执行到,但为了完整性添加 + setTimeout(() => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false })); + }, 3000); + } + }; + // 包装的邮箱登录选项点击处理 const handleEmailLoginClick = () => { setEmailLoginLoading(true); @@ -572,6 +591,23 @@ const LoginForm = () => { )} + {status.custom_oauth_providers && + status.custom_oauth_providers.map((provider) => ( + } + onClick={() => handleCustomOAuthClick(provider)} + loading={customOAuthLoading[provider.slug]} + > + + {t('使用 {{name}} 继续', { name: provider.name })} + + + ))} + {status.telegram_oauth && (