mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:40:59 +00:00
- Add support for custom OAuth providers, including creation, retrieval, updating, and deletion. - Introduce new model and controller for managing custom OAuth providers. - Enhance existing OAuth logic to accommodate custom providers. - Update API routes for custom OAuth provider management. - Include i18n support for custom OAuth-related messages.
159 lines
5.9 KiB
Go
159 lines
5.9 KiB
Go
package model
|
|
|
|
import (
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// CustomOAuthProvider holds the persisted configuration of one custom OAuth
// provider.
type CustomOAuthProvider struct {
	// Id is the primary key.
	Id int `json:"id" gorm:"primaryKey"`

	// Name is the display name, e.g., "GitHub Enterprise".
	Name string `json:"name" gorm:"type:varchar(64);not null"`

	// Slug is the URL identifier, e.g., "github-enterprise".
	Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"`

	// Enabled reports whether this provider is enabled.
	Enabled bool `json:"enabled" gorm:"default:false"`

	// ClientId is the OAuth client ID.
	ClientId string `json:"client_id" gorm:"type:varchar(256)"`

	// ClientSecret is the OAuth client secret. The `json:"-"` tag keeps it out
	// of JSON, so it is not returned to the frontend.
	ClientSecret string `json:"-" gorm:"type:varchar(512)"`

	// AuthorizationEndpoint is the authorization URL.
	AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"`

	// TokenEndpoint is the token exchange URL.
	TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"`

	// UserInfoEndpoint is the user info URL.
	UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"`

	// Scopes lists the OAuth scopes to request.
	Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"`

	// Field mapping configuration (supports JSONPath via gjson).

	// UserIdField is the user ID field path, e.g., "sub", "id", "data.user.id".
	UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"`

	// UsernameField is the username field path.
	UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"`

	// DisplayNameField is the display name field path.
	DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"`

	// EmailField is the email field path.
	EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"`

	// Advanced options.

	// WellKnown is the OIDC discovery endpoint (optional).
	WellKnown string `json:"well_known" gorm:"type:varchar(512)"`

	// AuthStyle selects how client credentials are sent:
	// 0=auto, 1=params, 2=header (Basic Auth).
	AuthStyle int `json:"auth_style" gorm:"default:0"`

	// CreatedAt and UpdatedAt are the record timestamps.
	CreatedAt time.Time `json:"created_at"`

	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
func (CustomOAuthProvider) TableName() string {
|
|
return "custom_oauth_providers"
|
|
}
|
|
|
|
// GetAllCustomOAuthProviders returns all custom OAuth providers
|
|
func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
|
|
var providers []*CustomOAuthProvider
|
|
err := DB.Order("id asc").Find(&providers).Error
|
|
return providers, err
|
|
}
|
|
|
|
// GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
|
|
func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
|
|
var providers []*CustomOAuthProvider
|
|
err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
|
|
return providers, err
|
|
}
|
|
|
|
// GetCustomOAuthProviderById returns a custom OAuth provider by ID
|
|
func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
|
|
var provider CustomOAuthProvider
|
|
err := DB.First(&provider, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &provider, nil
|
|
}
|
|
|
|
// GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
|
|
func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
|
|
var provider CustomOAuthProvider
|
|
err := DB.Where("slug = ?", slug).First(&provider).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &provider, nil
|
|
}
|
|
|
|
// CreateCustomOAuthProvider creates a new custom OAuth provider
|
|
func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
|
if err := validateCustomOAuthProvider(provider); err != nil {
|
|
return err
|
|
}
|
|
return DB.Create(provider).Error
|
|
}
|
|
|
|
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
|
func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
|
if err := validateCustomOAuthProvider(provider); err != nil {
|
|
return err
|
|
}
|
|
return DB.Save(provider).Error
|
|
}
|
|
|
|
// DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
|
|
func DeleteCustomOAuthProvider(id int) error {
|
|
// First, delete all user bindings for this provider
|
|
if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
|
|
return err
|
|
}
|
|
return DB.Delete(&CustomOAuthProvider{}, id).Error
|
|
}
|
|
|
|
// IsSlugTaken checks if a slug is already taken by another provider
|
|
func IsSlugTaken(slug string, excludeId int) bool {
|
|
var count int64
|
|
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
|
|
if excludeId > 0 {
|
|
query = query.Where("id != ?", excludeId)
|
|
}
|
|
query.Count(&count)
|
|
return count > 0
|
|
}
|
|
|
|
// validateCustomOAuthProvider validates a custom OAuth provider configuration
|
|
func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
|
if provider.Name == "" {
|
|
return errors.New("provider name is required")
|
|
}
|
|
if provider.Slug == "" {
|
|
return errors.New("provider slug is required")
|
|
}
|
|
// Slug must be lowercase and contain only alphanumeric characters and hyphens
|
|
slug := strings.ToLower(provider.Slug)
|
|
for _, c := range slug {
|
|
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
|
|
return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
|
|
}
|
|
}
|
|
provider.Slug = slug
|
|
|
|
if provider.ClientId == "" {
|
|
return errors.New("client ID is required")
|
|
}
|
|
if provider.AuthorizationEndpoint == "" {
|
|
return errors.New("authorization endpoint is required")
|
|
}
|
|
if provider.TokenEndpoint == "" {
|
|
return errors.New("token endpoint is required")
|
|
}
|
|
if provider.UserInfoEndpoint == "" {
|
|
return errors.New("user info endpoint is required")
|
|
}
|
|
|
|
// Set defaults for field mappings if empty
|
|
if provider.UserIdField == "" {
|
|
provider.UserIdField = "sub"
|
|
}
|
|
if provider.UsernameField == "" {
|
|
provider.UsernameField = "preferred_username"
|
|
}
|
|
if provider.DisplayNameField == "" {
|
|
provider.DisplayNameField = "name"
|
|
}
|
|
if provider.EmailField == "" {
|
|
provider.EmailField = "email"
|
|
}
|
|
if provider.Scopes == "" {
|
|
provider.Scopes = "openid profile email"
|
|
}
|
|
|
|
return nil
|
|
}
|