feat(oauth): implement custom OAuth provider management #1106

- Add support for custom OAuth providers, including creation, retrieval, updating, and deletion.
- Introduce new model and controller for managing custom OAuth providers.
- Enhance existing OAuth logic to accommodate custom providers.
- Update API routes for custom OAuth provider management.
- Include i18n support for custom OAuth-related messages.
This commit is contained in:
CaIon
2026-02-05 21:18:43 +08:00
parent 632baadb57
commit af54ea85d2
20 changed files with 2066 additions and 11 deletions

View File

@@ -0,0 +1,158 @@
package model
import (
"errors"
"strings"
"time"
)
// CustomOAuthProvider stores configuration for custom OAuth providers
type CustomOAuthProvider struct {
Id int `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
// Field mapping configuration (supports JSONPath via gjson)
UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
// Advanced options
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (CustomOAuthProvider) TableName() string {
return "custom_oauth_providers"
}
// GetAllCustomOAuthProviders returns all custom OAuth providers
func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
var providers []*CustomOAuthProvider
err := DB.Order("id asc").Find(&providers).Error
return providers, err
}
// GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
var providers []*CustomOAuthProvider
err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
return providers, err
}
// GetCustomOAuthProviderById returns a custom OAuth provider by ID
func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
var provider CustomOAuthProvider
err := DB.First(&provider, id).Error
if err != nil {
return nil, err
}
return &provider, nil
}
// GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
var provider CustomOAuthProvider
err := DB.Where("slug = ?", slug).First(&provider).Error
if err != nil {
return nil, err
}
return &provider, nil
}
// CreateCustomOAuthProvider creates a new custom OAuth provider
func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if err := validateCustomOAuthProvider(provider); err != nil {
return err
}
return DB.Create(provider).Error
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if err := validateCustomOAuthProvider(provider); err != nil {
return err
}
return DB.Save(provider).Error
}
// DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
func DeleteCustomOAuthProvider(id int) error {
// First, delete all user bindings for this provider
if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
return err
}
return DB.Delete(&CustomOAuthProvider{}, id).Error
}
// IsSlugTaken checks if a slug is already taken by another provider
func IsSlugTaken(slug string, excludeId int) bool {
var count int64
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
if excludeId > 0 {
query = query.Where("id != ?", excludeId)
}
query.Count(&count)
return count > 0
}
// validateCustomOAuthProvider validates a custom OAuth provider configuration
func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if provider.Name == "" {
return errors.New("provider name is required")
}
if provider.Slug == "" {
return errors.New("provider slug is required")
}
// Slug must be lowercase and contain only alphanumeric characters and hyphens
slug := strings.ToLower(provider.Slug)
for _, c := range slug {
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
}
}
provider.Slug = slug
if provider.ClientId == "" {
return errors.New("client ID is required")
}
if provider.AuthorizationEndpoint == "" {
return errors.New("authorization endpoint is required")
}
if provider.TokenEndpoint == "" {
return errors.New("token endpoint is required")
}
if provider.UserInfoEndpoint == "" {
return errors.New("user info endpoint is required")
}
// Set defaults for field mappings if empty
if provider.UserIdField == "" {
provider.UserIdField = "sub"
}
if provider.UsernameField == "" {
provider.UsernameField = "preferred_username"
}
if provider.DisplayNameField == "" {
provider.DisplayNameField = "name"
}
if provider.EmailField == "" {
provider.EmailField = "email"
}
if provider.Scopes == "" {
provider.Scopes = "openid profile email"
}
return nil
}

View File

@@ -274,6 +274,8 @@ func migrateDB() error {
&SubscriptionOrder{},
&UserSubscription{},
&SubscriptionPreConsumeRecord{},
&CustomOAuthProvider{},
&UserOAuthBinding{},
)
if err != nil {
return err
@@ -320,6 +322,8 @@ func migrateDBFast() error {
{&SubscriptionOrder{}, "SubscriptionOrder"},
{&UserSubscription{}, "UserSubscription"},
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
{&UserOAuthBinding{}, "UserOAuthBinding"},
}
// 动态计算migration数量确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))

125
model/user_oauth_binding.go Normal file
View File

@@ -0,0 +1,125 @@
package model
import (
"errors"
"time"
)
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
type UserOAuthBinding struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"index;not null"` // User ID
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
CreatedAt time.Time `json:"created_at"`
// Composite unique index to prevent duplicate bindings
// One OAuth account can only be bound to one user
}
func (UserOAuthBinding) TableName() string {
return "user_oauth_bindings"
}
// GetUserOAuthBindingsByUserId returns all OAuth bindings for a user
func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
var bindings []*UserOAuthBinding
err := DB.Where("user_id = ?", userId).Find(&bindings).Error
return bindings, err
}
// GetUserOAuthBinding returns a specific binding for a user and provider
func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
var binding UserOAuthBinding
err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
if err != nil {
return nil, err
}
return &binding, nil
}
// GetUserByOAuthBinding finds a user by provider ID and provider user ID
func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
var binding UserOAuthBinding
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error
if err != nil {
return nil, err
}
var user User
err = DB.First(&user, binding.UserId).Error
if err != nil {
return nil, err
}
return &user, nil
}
// IsProviderUserIdTaken checks if a provider user ID is already bound to any user
func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
var count int64
DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count)
return count > 0
}
// CreateUserOAuthBinding creates a new OAuth binding
func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
if binding.UserId == 0 {
return errors.New("user ID is required")
}
if binding.ProviderId == 0 {
return errors.New("provider ID is required")
}
if binding.ProviderUserId == "" {
return errors.New("provider user ID is required")
}
// Check if this provider user ID is already taken
if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) {
return errors.New("this OAuth account is already bound to another user")
}
binding.CreatedAt = time.Now()
return DB.Create(binding).Error
}
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
// Check if the new provider user ID is already taken by another user
var existingBinding UserOAuthBinding
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error
if err == nil && existingBinding.UserId != userId {
return errors.New("this OAuth account is already bound to another user")
}
// Check if user already has a binding for this provider
var binding UserOAuthBinding
err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
if err != nil {
// No existing binding, create new one
return CreateUserOAuthBinding(&UserOAuthBinding{
UserId: userId,
ProviderId: providerId,
ProviderUserId: newProviderUserId,
})
}
// Update existing binding
return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
}
// DeleteUserOAuthBinding deletes an OAuth binding
func DeleteUserOAuthBinding(userId, providerId int) error {
return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error
}
// DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user
func DeleteUserOAuthBindingsByUserId(userId int) error {
return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error
}
// GetBindingCountByProviderId returns the number of bindings for a provider
func GetBindingCountByProviderId(providerId int) (int64, error) {
var count int64
err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error
return count, err
}