From 2567cff6c8f5a5f40154427aeb032d6678090db2 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 5 Feb 2026 21:48:05 +0800 Subject: [PATCH 1/2] fix(oauth): enhance error handling and transaction management for OAuth user creation and binding - Improve error handling in DeleteCustomOAuthProvider to log and return errors when fetching binding counts. - Refactor user creation and OAuth binding logic to use transactions for atomic operations, ensuring data integrity. - Add unique constraints to UserOAuthBinding model to prevent duplicate bindings. - Enhance GitHub OAuth provider error logging for non-200 responses. - Update AccountManagement component to provide clearer error messages on API failures. --- controller/custom_oauth.go | 7 +- controller/oauth.go | 70 ++++++++++++++----- model/custom_oauth_provider.go | 7 +- model/user.go | 59 ++++++++++++++++ model/user_oauth_binding.go | 34 +++++++-- oauth/github.go | 12 ++++ .../personal/cards/AccountManagement.jsx | 6 +- 7 files changed, 168 insertions(+), 27 deletions(-) diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go index a4acfc38a..f26a55891 100644 --- a/controller/custom_oauth.go +++ b/controller/custom_oauth.go @@ -296,7 +296,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) { } // Check if there are any user bindings - count, _ := model.GetBindingCountByProviderId(id) + count, err := model.GetBindingCountByProviderId(id) + if err != nil { + common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error()) + common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试") + return + } if count > 0 { common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。") return diff --git a/controller/oauth.go b/controller/oauth.go index 58cb40d5c..65e18f9da 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -11,6 +11,7 @@ import ( "github.com/QuantumNous/new-api/oauth" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "gorm.io/gorm" ) // providerParams returns map with Provider key for i18n templates @@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) } - if err := user.Insert(inviterId); err != nil { - return nil, err - } - - // For custom providers, create the binding after user is created + // Use transaction to ensure user creation and OAuth binding are atomic if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok { - binding := &model.UserOAuthBinding{ - UserId: user.Id, - ProviderId: genericProvider.GetProviderId(), - ProviderUserId: oauthUser.ProviderUserID, - } - if err := model.CreateUserOAuthBinding(binding); err != nil { - common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error())) - // Don't fail the registration, just log the error + // Custom provider: create user and binding in a transaction + err := model.DB.Transaction(func(tx *gorm.DB) error { + // Create user + if err := user.InsertWithTx(tx, inviterId); err != nil { + return err + } + + // Create OAuth binding + binding := &model.UserOAuthBinding{ + UserId: user.Id, + ProviderId: genericProvider.GetProviderId(), + ProviderUserId: oauthUser.ProviderUserID, + } + if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err } + + // Perform post-transaction tasks (logs, sidebar config, inviter rewards) + user.FinalizeOAuthUserCreation(inviterId) } else { - // Built-in provider: set the provider user ID on the user model - provider.SetProviderUserID(user, oauthUser.ProviderUserID) - if err := user.Update(false); err != nil { - common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error())) + // Built-in provider: create user and update provider ID in a transaction + err := model.DB.Transaction(func(tx *gorm.DB) error { + // Create user + if err := user.InsertWithTx(tx, inviterId); err != nil { + return err + } + + // Set the provider user ID on the user model and update + provider.SetProviderUserID(user, oauthUser.ProviderUserID) + if err := tx.Model(user).Updates(map[string]interface{}{ + "github_id": user.GitHubId, + "discord_id": user.DiscordId, + "oidc_id": user.OidcId, + "linux_do_id": user.LinuxDOId, + "wechat_id": user.WeChatId, + "telegram_id": user.TelegramId, + }).Error; err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err } + + // Perform post-transaction tasks + user.FinalizeOAuthUserCreation(inviterId) } return user, nil diff --git a/model/custom_oauth_provider.go b/model/custom_oauth_provider.go index 884e87b06..43c69833a 100644 --- a/model/custom_oauth_provider.go +++ b/model/custom_oauth_provider.go @@ -97,13 +97,18 @@ func DeleteCustomOAuthProvider(id int) error { } // IsSlugTaken checks if a slug is already taken by another provider +// Returns true on DB errors (fail-closed) to prevent slug conflicts func IsSlugTaken(slug string, excludeId int) bool { var count int64 query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug) if excludeId > 0 { query = query.Where("id != ?", excludeId) } - query.Count(&count) + res := query.Count(&count) + if res.Error != nil { + // Fail-closed: treat DB errors as slug being taken to prevent conflicts + return true + } return count > 0 } diff --git a/model/user.go b/model/user.go index 47508a0bb..e0c9c686f 100644 --- a/model/user.go +++ b/model/user.go @@ -429,6 +429,65 @@ func (user *User) Insert(inviterId int) error { return nil } +// InsertWithTx inserts a new user within an existing transaction. +// This is used for OAuth registration where user creation and binding need to be atomic. +// Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits. +func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error { + var err error + if user.Password != "" { + user.Password, err = common.Password2Hash(user.Password) + if err != nil { + return err + } + } + user.Quota = common.QuotaForNewUser + user.AffCode = common.GetRandomString(4) + + // 初始化用户设置 + if user.Setting == "" { + defaultSetting := dto.UserSetting{} + user.SetSetting(defaultSetting) + } + + result := tx.Create(user) + if result.Error != nil { + return result.Error + } + + return nil +} + +// FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation. +// This should be called after the transaction commits successfully. +func (user *User) FinalizeOAuthUserCreation(inviterId int) { + // 用户创建成功后,根据角色初始化边栏配置 + var createdUser User + if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil { + defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role) + if defaultSidebarConfig != "" { + currentSetting := createdUser.GetSetting() + currentSetting.SidebarModules = defaultSidebarConfig + createdUser.SetSetting(currentSetting) + createdUser.Update(false) + common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role)) + } + } + + if common.QuotaForNewUser > 0 { + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) + } + if inviterId != 0 { + if common.QuotaForInvitee > 0 { + _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) + } + if common.QuotaForInviter > 0 { + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) + _ = inviteUser(inviterId) + } + } +} + func (user *User) Update(updatePassword bool) error { var err error if updatePassword { diff --git a/model/user_oauth_binding.go b/model/user_oauth_binding.go index 7b2acd474..492166251 100644 --- a/model/user_oauth_binding.go +++ b/model/user_oauth_binding.go @@ -3,18 +3,17 @@ package model import ( "errors" "time" + + "gorm.io/gorm" ) // UserOAuthBinding stores the binding relationship between users and custom OAuth providers type UserOAuthBinding struct { Id int `json:"id" gorm:"primaryKey"` - UserId int `json:"user_id" gorm:"index;not null"` // User ID - ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID - ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider + UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider + ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID + ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider CreatedAt time.Time `json:"created_at"` - - // Composite unique index to prevent duplicate bindings - // One OAuth account can only be bound to one user } func (UserOAuthBinding) TableName() string { @@ -82,6 +81,29 @@ func CreateUserOAuthBinding(binding *UserOAuthBinding) error { return DB.Create(binding).Error } +// CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction +func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error { + if binding.UserId == 0 { + return errors.New("user ID is required") + } + if binding.ProviderId == 0 { + return errors.New("provider ID is required") + } + if binding.ProviderUserId == "" { + return errors.New("provider user ID is required") + } + + // Check if this provider user ID is already taken (use tx to check within the same transaction) + var count int64 + tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count) + if count > 0 { + return errors.New("this OAuth account is already bound to another user") + } + + binding.CreatedAt = time.Now() + return tx.Create(binding).Error +} + // UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account) func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error { // Check if the new provider user ID is already taken by another user diff --git a/oauth/github.go b/oauth/github.go index e38f8a784..314118a37 100644 --- a/oauth/github.go +++ b/oauth/github.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "strconv" "time" @@ -122,6 +123,17 @@ func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*O logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode) + // Check for non-200 status codes before attempting to decode + if res.StatusCode != http.StatusOK { + body, _ := io.ReadAll(res.Body) + bodyStr := string(body) + if len(bodyStr) > 500 { + bodyStr = bodyStr[:500] + "..." + } + logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo failed: status=%d, body=%s", res.StatusCode, bodyStr)) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, map[string]any{"Provider": "GitHub"}, fmt.Sprintf("status %d", res.StatusCode)) + } + var githubUser gitHubUser err = json.NewDecoder(res.Body).Decode(&githubUser) if err != nil { diff --git a/web/src/components/settings/personal/cards/AccountManagement.jsx b/web/src/components/settings/personal/cards/AccountManagement.jsx index 1499c170f..bc27630ba 100644 --- a/web/src/components/settings/personal/cards/AccountManagement.jsx +++ b/web/src/components/settings/personal/cards/AccountManagement.jsx @@ -107,9 +107,11 @@ const AccountManagement = ({ const res = await API.get('/api/user/oauth/bindings'); if (res.data.success) { setCustomOAuthBindings(res.data.data || []); + } else { + showError(res.data.message || t('获取绑定信息失败')); } } catch (error) { - // ignore + showError(error.response?.data?.message || error.message || t('获取绑定信息失败')); } }; @@ -131,7 +133,7 @@ const AccountManagement = ({ showError(res.data.message); } } catch (error) { - showError(t('操作失败')); + showError(error.response?.data?.message || error.message || t('操作失败')); } finally { setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false })); } From e8d26e52d80d45f637e1fab3db11dfd250cd8e78 Mon Sep 17 00:00:00 2001 From: CaIon Date: Thu, 5 Feb 2026 22:03:30 +0800 Subject: [PATCH 2/2] refactor(oauth): update UpdateCustomOAuthProviderRequest to use pointers for optional fields - Change fields in UpdateCustomOAuthProviderRequest struct to use pointers for optional values, allowing for better handling of nil cases. - Update UpdateCustomOAuthProvider function to check for nil before assigning optional fields, ensuring existing values are preserved when not provided. --- controller/custom_oauth.go | 42 ++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go index f26a55891..e2245f880 100644 --- a/controller/custom_oauth.go +++ b/controller/custom_oauth.go @@ -166,21 +166,21 @@ func CreateCustomOAuthProvider(c *gin.Context) { // UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider type UpdateCustomOAuthProviderRequest struct { - Name string `json:"name"` - Slug string `json:"slug"` - Enabled bool `json:"enabled"` - ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - UserInfoEndpoint string `json:"user_info_endpoint"` - Scopes string `json:"scopes"` - UserIdField string `json:"user_id_field"` - UsernameField string `json:"username_field"` - DisplayNameField string `json:"display_name_field"` - EmailField string `json:"email_field"` - WellKnown string `json:"well_known"` - AuthStyle int `json:"auth_style"` + Name string `json:"name"` + Slug string `json:"slug"` + Enabled *bool `json:"enabled"` // Optional: if nil, keep existing + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"user_info_endpoint"` + Scopes string `json:"scopes"` + UserIdField string `json:"user_id_field"` + UsernameField string `json:"username_field"` + DisplayNameField string `json:"display_name_field"` + EmailField string `json:"email_field"` + WellKnown *string `json:"well_known"` // Optional: if nil, keep existing + AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing } // UpdateCustomOAuthProvider updates an existing custom OAuth provider @@ -227,7 +227,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.Slug != "" { provider.Slug = req.Slug } - provider.Enabled = req.Enabled + if req.Enabled != nil { + provider.Enabled = *req.Enabled + } if req.ClientId != "" { provider.ClientId = req.ClientId } @@ -258,8 +260,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.EmailField != "" { provider.EmailField = req.EmailField } - provider.WellKnown = req.WellKnown - provider.AuthStyle = req.AuthStyle + if req.WellKnown != nil { + provider.WellKnown = *req.WellKnown + } + if req.AuthStyle != nil { + provider.AuthStyle = *req.AuthStyle + } if err := model.UpdateCustomOAuthProvider(provider); err != nil { common.ApiError(c, err)