Merge pull request #2857 from QuantumNous/feat/custom-oauth

feat(oauth): implement custom OAuth provider
This commit is contained in:
Calcium-Ion
2026-02-08 00:13:20 +08:00
committed by GitHub
7 changed files with 192 additions and 45 deletions

View File

@@ -166,21 +166,21 @@ func CreateCustomOAuthProvider(c *gin.Context) {
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
type UpdateCustomOAuthProviderRequest struct {
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
Name string `json:"name"`
Slug string `json:"slug"`
Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
@@ -227,7 +227,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.Slug != "" {
provider.Slug = req.Slug
}
provider.Enabled = req.Enabled
if req.Enabled != nil {
provider.Enabled = *req.Enabled
}
if req.ClientId != "" {
provider.ClientId = req.ClientId
}
@@ -258,8 +260,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.EmailField != "" {
provider.EmailField = req.EmailField
}
provider.WellKnown = req.WellKnown
provider.AuthStyle = req.AuthStyle
if req.WellKnown != nil {
provider.WellKnown = *req.WellKnown
}
if req.AuthStyle != nil {
provider.AuthStyle = *req.AuthStyle
}
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
@@ -296,7 +302,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
}
// Check if there are any user bindings
count, _ := model.GetBindingCountByProviderId(id)
count, err := model.GetBindingCountByProviderId(id)
if err != nil {
common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
return
}
if count > 0 {
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
return

View File

@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// providerParams returns map with Provider key for i18n templates
@@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
return nil, err
}
// For custom providers, create the binding after user is created
// Use transaction to ensure user creation and OAuth binding are atomic
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBinding(binding); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
// Don't fail the registration, just log the error
// Custom provider: create user and binding in a transaction
err := model.DB.Transaction(func(tx *gorm.DB) error {
// Create user
if err := user.InsertWithTx(tx, inviterId); err != nil {
return err
}
// Create OAuth binding
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// Perform post-transaction tasks (logs, sidebar config, inviter rewards)
user.FinalizeOAuthUserCreation(inviterId)
} else {
// Built-in provider: set the provider user ID on the user model
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := user.Update(false); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
// Built-in provider: create user and update provider ID in a transaction
err := model.DB.Transaction(func(tx *gorm.DB) error {
// Create user
if err := user.InsertWithTx(tx, inviterId); err != nil {
return err
}
// Set the provider user ID on the user model and update
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := tx.Model(user).Updates(map[string]interface{}{
"github_id": user.GitHubId,
"discord_id": user.DiscordId,
"oidc_id": user.OidcId,
"linux_do_id": user.LinuxDOId,
"wechat_id": user.WeChatId,
"telegram_id": user.TelegramId,
}).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// Perform post-transaction tasks
user.FinalizeOAuthUserCreation(inviterId)
}
return user, nil

View File

@@ -97,13 +97,18 @@ func DeleteCustomOAuthProvider(id int) error {
}
// IsSlugTaken checks if a slug is already taken by another provider
// Returns true on DB errors (fail-closed) to prevent slug conflicts
func IsSlugTaken(slug string, excludeId int) bool {
var count int64
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
if excludeId > 0 {
query = query.Where("id != ?", excludeId)
}
query.Count(&count)
res := query.Count(&count)
if res.Error != nil {
// Fail-closed: treat DB errors as slug being taken to prevent conflicts
return true
}
return count > 0
}

View File

@@ -429,6 +429,65 @@ func (user *User) Insert(inviterId int) error {
return nil
}
// InsertWithTx inserts a new user within an existing transaction.
// This is used for OAuth registration where user creation and binding need to be atomic.
// Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
var err error
if user.Password != "" {
user.Password, err = common.Password2Hash(user.Password)
if err != nil {
return err
}
}
user.Quota = common.QuotaForNewUser
user.AffCode = common.GetRandomString(4)
// 初始化用户设置
if user.Setting == "" {
defaultSetting := dto.UserSetting{}
user.SetSetting(defaultSetting)
}
result := tx.Create(user)
if result.Error != nil {
return result.Error
}
return nil
}
// FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
// This should be called after the transaction commits successfully.
func (user *User) FinalizeOAuthUserCreation(inviterId int) {
// 用户创建成功后,根据角色初始化边栏配置
var createdUser User
if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
if defaultSidebarConfig != "" {
currentSetting := createdUser.GetSetting()
currentSetting.SidebarModules = defaultSidebarConfig
createdUser.SetSetting(currentSetting)
createdUser.Update(false)
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
}
}
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
_ = inviteUser(inviterId)
}
}
}
func (user *User) Update(updatePassword bool) error {
var err error
if updatePassword {

View File

@@ -3,18 +3,17 @@ package model
import (
"errors"
"time"
"gorm.io/gorm"
)
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
type UserOAuthBinding struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"index;not null"` // User ID
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
CreatedAt time.Time `json:"created_at"`
// Composite unique index to prevent duplicate bindings
// One OAuth account can only be bound to one user
}
func (UserOAuthBinding) TableName() string {
@@ -82,6 +81,29 @@ func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
return DB.Create(binding).Error
}
// CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction
func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
if binding.UserId == 0 {
return errors.New("user ID is required")
}
if binding.ProviderId == 0 {
return errors.New("provider ID is required")
}
if binding.ProviderUserId == "" {
return errors.New("provider user ID is required")
}
// Check if this provider user ID is already taken (use tx to check within the same transaction)
var count int64
tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count)
if count > 0 {
return errors.New("this OAuth account is already bound to another user")
}
binding.CreatedAt = time.Now()
return tx.Create(binding).Error
}
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
// Check if the new provider user ID is already taken by another user

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"
@@ -122,6 +123,17 @@ func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*O
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
// Check for non-200 status codes before attempting to decode
if res.StatusCode != http.StatusOK {
body, _ := io.ReadAll(res.Body)
bodyStr := string(body)
if len(bodyStr) > 500 {
bodyStr = bodyStr[:500] + "..."
}
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo failed: status=%d, body=%s", res.StatusCode, bodyStr))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, map[string]any{"Provider": "GitHub"}, fmt.Sprintf("status %d", res.StatusCode))
}
var githubUser gitHubUser
err = json.NewDecoder(res.Body).Decode(&githubUser)
if err != nil {

View File

@@ -107,9 +107,11 @@ const AccountManagement = ({
const res = await API.get('/api/user/oauth/bindings');
if (res.data.success) {
setCustomOAuthBindings(res.data.data || []);
} else {
showError(res.data.message || t('获取绑定信息失败'));
}
} catch (error) {
// ignore
showError(error.response?.data?.message || error.message || t('获取绑定信息失败'));
}
};
@@ -131,7 +133,7 @@ const AccountManagement = ({
showError(res.data.message);
}
} catch (error) {
showError(t('操作失败'));
showError(error.response?.data?.message || error.message || t('操作失败'));
} finally {
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
}