diff --git a/controller/discord.go b/controller/discord.go deleted file mode 100644 index a0865de51..000000000 --- a/controller/discord.go +++ /dev/null @@ -1,223 +0,0 @@ -package controller - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/setting/system_setting" - - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" -) - -type DiscordResponse struct { - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` -} - -type DiscordUser struct { - UID string `json:"id"` - ID string `json:"username"` - Name string `json:"global_name"` -} - -func getDiscordUserInfoByCode(code string) (*DiscordUser, error) { - if code == "" { - return nil, errors.New("无效的参数") - } - - values := url.Values{} - values.Set("client_id", system_setting.GetDiscordSettings().ClientId) - values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret) - values.Set("code", code) - values.Set("grant_type", "authorization_code") - values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)) - formData := values.Encode() - req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - client := http.Client{ - Timeout: 5 * time.Second, - } - res, err := client.Do(req) - if err != nil { - common.SysLog(err.Error()) - return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!") - } - defer res.Body.Close() - var discordResponse DiscordResponse - err = json.NewDecoder(res.Body).Decode(&discordResponse) - if err != nil { - return nil, err - } - - if discordResponse.AccessToken == "" { - common.SysError("Discord 获取 Token 失败,请检查设置!") - return nil, errors.New("Discord 获取 Token 失败,请检查设置!") - } - - req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken) - res2, err := client.Do(req) - if err != nil { - common.SysLog(err.Error()) - return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!") - } - defer res2.Body.Close() - if res2.StatusCode != http.StatusOK { - common.SysError("Discord 获取用户信息失败!请检查设置!") - return nil, errors.New("Discord 获取用户信息失败!请检查设置!") - } - - var discordUser DiscordUser - err = json.NewDecoder(res2.Body).Decode(&discordUser) - if err != nil { - return nil, err - } - if discordUser.UID == "" || discordUser.ID == "" { - common.SysError("Discord 获取用户信息为空!请检查设置!") - return nil, errors.New("Discord 获取用户信息为空!请检查设置!") - } - return &discordUser, nil -} - -func DiscordOAuth(c *gin.Context) { - session := sessions.Default(c) - state := c.Query("state") - if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "state is empty or not same", - }) - return - } - username := session.Get("username") - if username != nil { - DiscordBind(c) - return - } - if !system_setting.GetDiscordSettings().Enabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 Discord 登录以及注册", - }) - return - } - code := c.Query("code") - discordUser, err := getDiscordUserInfoByCode(code) - if err != nil { - common.ApiError(c, err) - return - } - user := model.User{ - DiscordId: discordUser.UID, - } - if model.IsDiscordIdAlreadyTaken(user.DiscordId) { - err := user.FillUserByDiscordId() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - } else { - if common.RegisterEnabled { - if discordUser.ID != "" { - user.Username = discordUser.ID - } else { - user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1) - } - if discordUser.Name != "" { - user.DisplayName = discordUser.Name - } else { - user.DisplayName = "Discord User" - } - err := user.Insert(0) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - } else { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员关闭了新用户注册", - }) - return - } - } - - if user.Status != common.UserStatusEnabled { - c.JSON(http.StatusOK, gin.H{ - "message": "用户已被封禁", - "success": false, - }) - return - } - setupLogin(&user, c) -} - -func DiscordBind(c *gin.Context) { - if !system_setting.GetDiscordSettings().Enabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 Discord 登录以及注册", - }) - return - } - code := c.Query("code") - discordUser, err := getDiscordUserInfoByCode(code) - if err != nil { - common.ApiError(c, err) - return - } - user := model.User{ - DiscordId: discordUser.UID, - } - if model.IsDiscordIdAlreadyTaken(user.DiscordId) { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "该 Discord 账户已被绑定", - }) - return - } - session := sessions.Default(c) - id := session.Get("id") - user.Id = id.(int) - err = user.FillUserById() - if err != nil { - common.ApiError(c, err) - return - } - user.DiscordId = discordUser.UID - err = user.Update(false) - if err != nil { - common.ApiError(c, err) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "bind", - }) -} diff --git a/controller/github.go b/controller/github.go deleted file mode 100644 index 79f27bca3..000000000 --- a/controller/github.go +++ /dev/null @@ -1,240 +0,0 @@ -package controller - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "net/http" - "strconv" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/model" - - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" -) - -type GitHubOAuthResponse struct { - AccessToken string `json:"access_token"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` -} - -type GitHubUser struct { - Login string `json:"login"` - Name string `json:"name"` - Email string `json:"email"` -} - -func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { - if code == "" { - return nil, errors.New("无效的参数") - } - values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} - jsonData, err := json.Marshal(values) - if err != nil { - return nil, err - } - req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - client := http.Client{ - Timeout: 20 * time.Second, - } - res, err := client.Do(req) - if err != nil { - common.SysLog(err.Error()) - return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") - } - defer res.Body.Close() - var oAuthResponse GitHubOAuthResponse - err = json.NewDecoder(res.Body).Decode(&oAuthResponse) - if err != nil { - return nil, err - } - req, err = http.NewRequest("GET", "https://api.github.com/user", nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) - res2, err := client.Do(req) - if err != nil { - common.SysLog(err.Error()) - return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") - } - defer res2.Body.Close() - var githubUser GitHubUser - err = json.NewDecoder(res2.Body).Decode(&githubUser) - if err != nil { - return nil, err - } - if githubUser.Login == "" { - return nil, errors.New("返回值非法,用户字段为空,请稍后重试!") - } - return &githubUser, nil -} - -func GitHubOAuth(c *gin.Context) { - session := sessions.Default(c) - state := c.Query("state") - if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "state is empty or not same", - }) - return - } - username := session.Get("username") - if username != nil { - GitHubBind(c) - return - } - - if !common.GitHubOAuthEnabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 GitHub 登录以及注册", - }) - return - } - code := c.Query("code") - githubUser, err := getGitHubUserInfoByCode(code) - if err != nil { - common.ApiError(c, err) - return - } - user := model.User{ - GitHubId: githubUser.Login, - } - // IsGitHubIdAlreadyTaken is unscoped - if model.IsGitHubIdAlreadyTaken(user.GitHubId) { - // FillUserByGitHubId is scoped - err := user.FillUserByGitHubId() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - // if user.Id == 0 , user has been deleted - if user.Id == 0 { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "用户已注销", - }) - return - } - } else { - if common.RegisterEnabled { - user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) - if githubUser.Name != "" { - user.DisplayName = githubUser.Name - } else { - user.DisplayName = "GitHub User" - } - user.Email = githubUser.Email - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled - affCode := session.Get("aff") - inviterId := 0 - if affCode != nil { - inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) - } - - if err := user.Insert(inviterId); err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - } else { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员关闭了新用户注册", - }) - return - } - } - - if user.Status != common.UserStatusEnabled { - c.JSON(http.StatusOK, gin.H{ - "message": "用户已被封禁", - "success": false, - }) - return - } - setupLogin(&user, c) -} - -func GitHubBind(c *gin.Context) { - if !common.GitHubOAuthEnabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 GitHub 登录以及注册", - }) - return - } - code := c.Query("code") - githubUser, err := getGitHubUserInfoByCode(code) - if err != nil { - common.ApiError(c, err) - return - } - user := model.User{ - GitHubId: githubUser.Login, - } - if model.IsGitHubIdAlreadyTaken(user.GitHubId) { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "该 GitHub 账户已被绑定", - }) - return - } - session := sessions.Default(c) - id := session.Get("id") - // id := c.GetInt("id") // critical bug! - user.Id = id.(int) - err = user.FillUserById() - if err != nil { - common.ApiError(c, err) - return - } - user.GitHubId = githubUser.Login - err = user.Update(false) - if err != nil { - common.ApiError(c, err) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "bind", - }) - return -} - -func GenerateOAuthCode(c *gin.Context) { - session := sessions.Default(c) - state := common.GetRandomString(12) - affCode := c.Query("aff") - if affCode != "" { - session.Set("aff", affCode) - } - session.Set("oauth_state", state) - err := session.Save() - if err != nil { - common.ApiError(c, err) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": state, - }) -} diff --git a/controller/linuxdo.go b/controller/linuxdo.go deleted file mode 100644 index 5457c9a4f..000000000 --- a/controller/linuxdo.go +++ /dev/null @@ -1,268 +0,0 @@ -package controller - -import ( - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/model" - - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" -) - -type LinuxdoUser struct { - Id int `json:"id"` - Username string `json:"username"` - Name string `json:"name"` - Active bool `json:"active"` - TrustLevel int `json:"trust_level"` - Silenced bool `json:"silenced"` -} - -func LinuxDoBind(c *gin.Context) { - if !common.LinuxDOOAuthEnabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 Linux DO 登录以及注册", - }) - return - } - - code := c.Query("code") - linuxdoUser, err := getLinuxdoUserInfoByCode(code, c) - if err != nil { - common.ApiError(c, err) - return - } - - user := model.User{ - LinuxDOId: strconv.Itoa(linuxdoUser.Id), - } - - if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "该 Linux DO 账户已被绑定", - }) - return - } - - session := sessions.Default(c) - id := session.Get("id") - user.Id = id.(int) - - err = user.FillUserById() - if err != nil { - common.ApiError(c, err) - return - } - - user.LinuxDOId = strconv.Itoa(linuxdoUser.Id) - err = user.Update(false) - if err != nil { - common.ApiError(c, err) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "bind", - }) -} - -func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) { - if code == "" { - return nil, errors.New("invalid code") - } - - // Get access token using Basic auth - tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token") - credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret - basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials)) - - // Get redirect URI from request - scheme := "http" - if c.Request.TLS != nil { - scheme = "https" - } - redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host) - - data := url.Values{} - data.Set("grant_type", "authorization_code") - data.Set("code", code) - data.Set("redirect_uri", redirectURI) - - req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode())) - if err != nil { - return nil, err - } - - req.Header.Set("Authorization", basicAuth) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - client := http.Client{Timeout: 5 * time.Second} - res, err := client.Do(req) - if err != nil { - return nil, errors.New("failed to connect to Linux DO server") - } - defer res.Body.Close() - - var tokenRes struct { - AccessToken string `json:"access_token"` - Message string `json:"message"` - } - if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil { - return nil, err - } - - if tokenRes.AccessToken == "" { - return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message) - } - - // Get user info - userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user") - req, err = http.NewRequest("GET", userEndpoint, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken) - req.Header.Set("Accept", "application/json") - - res2, err := client.Do(req) - if err != nil { - return nil, errors.New("failed to get user info from Linux DO") - } - defer res2.Body.Close() - - var linuxdoUser LinuxdoUser - if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil { - return nil, err - } - - if linuxdoUser.Id == 0 { - return nil, errors.New("invalid user info returned") - } - - return &linuxdoUser, nil -} - -func LinuxdoOAuth(c *gin.Context) { - session := sessions.Default(c) - - errorCode := c.Query("error") - if errorCode != "" { - errorDescription := c.Query("error_description") - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": errorDescription, - }) - return - } - - state := c.Query("state") - if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "state is empty or not same", - }) - return - } - - username := session.Get("username") - if username != nil { - LinuxDoBind(c) - return - } - - if !common.LinuxDOOAuthEnabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 Linux DO 登录以及注册", - }) - return - } - - code := c.Query("code") - linuxdoUser, err := getLinuxdoUserInfoByCode(code, c) - if err != nil { - common.ApiError(c, err) - return - } - - user := model.User{ - LinuxDOId: strconv.Itoa(linuxdoUser.Id), - } - - // Check if user exists - if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) { - err := user.FillUserByLinuxDOId() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - if user.Id == 0 { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "用户已注销", - }) - return - } - } else { - if common.RegisterEnabled { - if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel { - user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1) - user.DisplayName = linuxdoUser.Name - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled - - affCode := session.Get("aff") - inviterId := 0 - if affCode != nil { - inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) - } - - if err := user.Insert(inviterId); err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - } else { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "Linux DO 信任等级未达到管理员设置的最低信任等级", - }) - return - } - } else { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员关闭了新用户注册", - }) - return - } - } - - if user.Status != common.UserStatusEnabled { - c.JSON(http.StatusOK, gin.H{ - "message": "用户已被封禁", - "success": false, - }) - return - } - - setupLogin(&user, c) -} diff --git a/controller/oauth.go b/controller/oauth.go new file mode 100644 index 000000000..a24912933 --- /dev/null +++ b/controller/oauth.go @@ -0,0 +1,257 @@ +package controller + +import ( + "net/http" + "strconv" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/i18n" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/oauth" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +// providerParams returns map with Provider key for i18n templates +func providerParams(name string) map[string]any { + return map[string]any{"Provider": name} +} + +// GenerateOAuthCode generates a state code for OAuth CSRF protection +func GenerateOAuthCode(c *gin.Context) { + session := sessions.Default(c) + state := common.GetRandomString(12) + affCode := c.Query("aff") + if affCode != "" { + session.Set("aff", affCode) + } + session.Set("oauth_state", state) + err := session.Save() + if err != nil { + common.ApiError(c, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": state, + }) +} + +// HandleOAuth handles OAuth callback for all standard OAuth providers +func HandleOAuth(c *gin.Context) { + providerName := c.Param("provider") + provider := oauth.GetProvider(providerName) + if provider == nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": i18n.T(c, i18n.MsgOAuthUnknownProvider), + }) + return + } + + session := sessions.Default(c) + + // 1. Validate state (CSRF protection) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": i18n.T(c, i18n.MsgOAuthStateInvalid), + }) + return + } + + // 2. Check if user is already logged in (bind flow) + username := session.Get("username") + if username != nil { + handleOAuthBind(c, provider) + return + } + + // 3. Check if provider is enabled + if !provider.IsEnabled() { + common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName())) + return + } + + // 4. Handle error from provider + errorCode := c.Query("error") + if errorCode != "" { + errorDescription := c.Query("error_description") + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": errorDescription, + }) + return + } + + // 5. Exchange code for token + code := c.Query("code") + token, err := provider.ExchangeToken(c.Request.Context(), code, c) + if err != nil { + handleOAuthError(c, err) + return + } + + // 6. Get user info + oauthUser, err := provider.GetUserInfo(c.Request.Context(), token) + if err != nil { + handleOAuthError(c, err) + return + } + + // 7. Find or create user + user, err := findOrCreateOAuthUser(c, provider, oauthUser, session) + if err != nil { + switch err.(type) { + case *OAuthUserDeletedError: + common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted) + case *OAuthRegistrationDisabledError: + common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled) + default: + common.ApiError(c, err) + } + return + } + + // 8. Check user status + if user.Status != common.UserStatusEnabled { + common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned) + return + } + + // 9. Setup login + setupLogin(user, c) +} + +// handleOAuthBind handles binding OAuth account to existing user +func handleOAuthBind(c *gin.Context, provider oauth.Provider) { + if !provider.IsEnabled() { + common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName())) + return + } + + // Exchange code for token + code := c.Query("code") + token, err := provider.ExchangeToken(c.Request.Context(), code, c) + if err != nil { + handleOAuthError(c, err) + return + } + + // Get user info + oauthUser, err := provider.GetUserInfo(c.Request.Context(), token) + if err != nil { + handleOAuthError(c, err) + return + } + + // Check if this OAuth account is already bound + if provider.IsUserIDTaken(oauthUser.ProviderUserID) { + common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName())) + return + } + + // Get current user from session + session := sessions.Default(c) + id := session.Get("id") + user := model.User{Id: id.(int)} + err = user.FillUserById() + if err != nil { + common.ApiError(c, err) + return + } + + // Update user with OAuth ID + provider.SetProviderUserID(&user, oauthUser.ProviderUserID) + err = user.Update(false) + if err != nil { + common.ApiError(c, err) + return + } + + common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil) +} + +// findOrCreateOAuthUser finds existing user or creates new user +func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) { + user := &model.User{} + + // Check if user already exists + if provider.IsUserIDTaken(oauthUser.ProviderUserID) { + provider.SetProviderUserID(user, oauthUser.ProviderUserID) + err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID) + if err != nil { + return nil, err + } + // Check if user has been deleted + if user.Id == 0 { + return nil, &OAuthUserDeletedError{} + } + return user, nil + } + + // User doesn't exist, create new user if registration is enabled + if !common.RegisterEnabled { + return nil, &OAuthRegistrationDisabledError{} + } + + // Set up new user + user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1) + if oauthUser.DisplayName != "" { + user.DisplayName = oauthUser.DisplayName + } else if oauthUser.Username != "" { + user.DisplayName = oauthUser.Username + } else { + user.DisplayName = provider.GetName() + " User" + } + if oauthUser.Email != "" { + user.Email = oauthUser.Email + } + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + provider.SetProviderUserID(user, oauthUser.ProviderUserID) + + // Handle affiliate code + affCode := session.Get("aff") + inviterId := 0 + if affCode != nil { + inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) + } + + if err := user.Insert(inviterId); err != nil { + return nil, err + } + + return user, nil +} + +// Error types for OAuth +type OAuthUserDeletedError struct{} + +func (e *OAuthUserDeletedError) Error() string { + return "user has been deleted" +} + +type OAuthRegistrationDisabledError struct{} + +func (e *OAuthRegistrationDisabledError) Error() string { + return "registration is disabled" +} + +// handleOAuthError handles OAuth errors and returns translated message +func handleOAuthError(c *gin.Context, err error) { + switch e := err.(type) { + case *oauth.OAuthError: + if e.Params != nil { + common.ApiErrorI18n(c, e.MsgKey, e.Params) + } else { + common.ApiErrorI18n(c, e.MsgKey) + } + case *oauth.TrustLevelError: + common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow) + default: + common.ApiError(c, err) + } +} diff --git a/controller/oidc.go b/controller/oidc.go deleted file mode 100644 index ac49f84e1..000000000 --- a/controller/oidc.go +++ /dev/null @@ -1,228 +0,0 @@ -package controller - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/model" - "github.com/QuantumNous/new-api/setting/system_setting" - - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" -) - -type OidcResponse struct { - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` -} - -type OidcUser struct { - OpenID string `json:"sub"` - Email string `json:"email"` - Name string `json:"name"` - PreferredUsername string `json:"preferred_username"` - Picture string `json:"picture"` -} - -func getOidcUserInfoByCode(code string) (*OidcUser, error) { - if code == "" { - return nil, errors.New("无效的参数") - } - - values := url.Values{} - values.Set("client_id", system_setting.GetOIDCSettings().ClientId) - values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret) - values.Set("code", code) - values.Set("grant_type", "authorization_code") - values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)) - formData := values.Encode() - req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - client := http.Client{ - Timeout: 5 * time.Second, - } - res, err := client.Do(req) - if err != nil { - common.SysLog(err.Error()) - return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") - } - defer res.Body.Close() - var oidcResponse OidcResponse - err = json.NewDecoder(res.Body).Decode(&oidcResponse) - if err != nil { - return nil, err - } - - if oidcResponse.AccessToken == "" { - common.SysLog("OIDC 获取 Token 失败,请检查设置!") - return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") - } - - req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) - res2, err := client.Do(req) - if err != nil { - common.SysLog(err.Error()) - return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") - } - defer res2.Body.Close() - if res2.StatusCode != http.StatusOK { - common.SysLog("OIDC 获取用户信息失败!请检查设置!") - return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") - } - - var oidcUser OidcUser - err = json.NewDecoder(res2.Body).Decode(&oidcUser) - if err != nil { - return nil, err - } - if oidcUser.OpenID == "" || oidcUser.Email == "" { - common.SysLog("OIDC 获取用户信息为空!请检查设置!") - return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") - } - return &oidcUser, nil -} - -func OidcAuth(c *gin.Context) { - session := sessions.Default(c) - state := c.Query("state") - if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "state is empty or not same", - }) - return - } - username := session.Get("username") - if username != nil { - OidcBind(c) - return - } - if !system_setting.GetOIDCSettings().Enabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 OIDC 登录以及注册", - }) - return - } - code := c.Query("code") - oidcUser, err := getOidcUserInfoByCode(code) - if err != nil { - common.ApiError(c, err) - return - } - user := model.User{ - OidcId: oidcUser.OpenID, - } - if model.IsOidcIdAlreadyTaken(user.OidcId) { - err := user.FillUserByOidcId() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - } else { - if common.RegisterEnabled { - user.Email = oidcUser.Email - if oidcUser.PreferredUsername != "" { - user.Username = oidcUser.PreferredUsername - } else { - user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) - } - if oidcUser.Name != "" { - user.DisplayName = oidcUser.Name - } else { - user.DisplayName = "OIDC User" - } - err := user.Insert(0) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - } else { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员关闭了新用户注册", - }) - return - } - } - - if user.Status != common.UserStatusEnabled { - c.JSON(http.StatusOK, gin.H{ - "message": "用户已被封禁", - "success": false, - }) - return - } - setupLogin(&user, c) -} - -func OidcBind(c *gin.Context) { - if !system_setting.GetOIDCSettings().Enabled { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "管理员未开启通过 OIDC 登录以及注册", - }) - return - } - code := c.Query("code") - oidcUser, err := getOidcUserInfoByCode(code) - if err != nil { - common.ApiError(c, err) - return - } - user := model.User{ - OidcId: oidcUser.OpenID, - } - if model.IsOidcIdAlreadyTaken(user.OidcId) { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "该 OIDC 账户已被绑定", - }) - return - } - session := sessions.Default(c) - id := session.Get("id") - // id := c.GetInt("id") // critical bug! - user.Id = id.(int) - err = user.FillUserById() - if err != nil { - common.ApiError(c, err) - return - } - user.OidcId = oidcUser.OpenID - err = user.Update(false) - if err != nil { - common.ApiError(c, err) - return - } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "bind", - }) - return -} diff --git a/i18n/keys.go b/i18n/keys.go index 5de0d43b2..6ac0a574c 100644 --- a/i18n/keys.go +++ b/i18n/keys.go @@ -264,9 +264,20 @@ const ( // OAuth related messages const ( - MsgOAuthInvalidCode = "oauth.invalid_code" - MsgOAuthGetUserErr = "oauth.get_user_error" - MsgOAuthAccountUsed = "oauth.account_used" + MsgOAuthInvalidCode = "oauth.invalid_code" + MsgOAuthGetUserErr = "oauth.get_user_error" + MsgOAuthAccountUsed = "oauth.account_used" + MsgOAuthUnknownProvider = "oauth.unknown_provider" + MsgOAuthStateInvalid = "oauth.state_invalid" + MsgOAuthNotEnabled = "oauth.not_enabled" + MsgOAuthUserDeleted = "oauth.user_deleted" + MsgOAuthUserBanned = "oauth.user_banned" + MsgOAuthBindSuccess = "oauth.bind_success" + MsgOAuthAlreadyBound = "oauth.already_bound" + MsgOAuthConnectFailed = "oauth.connect_failed" + MsgOAuthTokenFailed = "oauth.token_failed" + MsgOAuthUserInfoEmpty = "oauth.user_info_empty" + MsgOAuthTrustLevelLow = "oauth.trust_level_low" ) // Model layer error messages (for translation in controller) diff --git a/i18n/locales/en.yaml b/i18n/locales/en.yaml index 994ff7837..e44f7ad7b 100644 --- a/i18n/locales/en.yaml +++ b/i18n/locales/en.yaml @@ -223,6 +223,17 @@ ability.repair_running: "A repair task is already running, please try again late oauth.invalid_code: "Invalid authorization code" oauth.get_user_error: "Failed to get user information" oauth.account_used: "This account has been bound to another user" +oauth.unknown_provider: "Unknown OAuth provider" +oauth.state_invalid: "State parameter is empty or mismatched" +oauth.not_enabled: "{{.Provider}} login and registration has not been enabled by administrator" +oauth.user_deleted: "User has been deleted" +oauth.user_banned: "User has been banned" +oauth.bind_success: "Binding successful" +oauth.already_bound: "This {{.Provider}} account has already been bound" +oauth.connect_failed: "Unable to connect to {{.Provider}} server, please try again later" +oauth.token_failed: "Failed to get token from {{.Provider}}, please check settings" +oauth.user_info_empty: "{{.Provider}} returned empty user info, please check settings" +oauth.trust_level_low: "Linux DO trust level does not meet the minimum required by administrator" # Model layer error messages redeem.failed: "Redemption failed, please try again later" diff --git a/i18n/locales/zh.yaml b/i18n/locales/zh.yaml index 58576ac7c..9098e977e 100644 --- a/i18n/locales/zh.yaml +++ b/i18n/locales/zh.yaml @@ -224,6 +224,17 @@ ability.repair_running: "已经有一个修复任务在运行中,请稍后再 oauth.invalid_code: "无效的授权码" oauth.get_user_error: "获取用户信息失败" oauth.account_used: "该账户已被其他用户绑定" +oauth.unknown_provider: "未知的 OAuth 提供商" +oauth.state_invalid: "state 参数为空或不匹配" +oauth.not_enabled: "管理员未开启通过 {{.Provider}} 登录以及注册" +oauth.user_deleted: "用户已注销" +oauth.user_banned: "用户已被封禁" +oauth.bind_success: "绑定成功" +oauth.already_bound: "该 {{.Provider}} 账户已被绑定" +oauth.connect_failed: "无法连接至 {{.Provider}} 服务器,请稍后重试" +oauth.token_failed: "{{.Provider}} 获取 Token 失败,请检查设置" +oauth.user_info_empty: "{{.Provider}} 获取用户信息为空,请检查设置" +oauth.trust_level_low: "Linux DO 信任等级未达到管理员设置的最低信任等级" # Model layer error messages redeem.failed: "兑换失败,请稍后重试" diff --git a/oauth/discord.go b/oauth/discord.go new file mode 100644 index 000000000..b626d2f82 --- /dev/null +++ b/oauth/discord.go @@ -0,0 +1,172 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/QuantumNous/new-api/i18n" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" +) + +func init() { + Register("discord", &DiscordProvider{}) +} + +// DiscordProvider implements OAuth for Discord +type DiscordProvider struct{} + +type discordOAuthResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type discordUser struct { + UID string `json:"id"` + ID string `json:"username"` + Name string `json:"global_name"` +} + +func (p *DiscordProvider) GetName() string { + return "Discord" +} + +func (p *DiscordProvider) IsEnabled() bool { + return system_setting.GetDiscordSettings().Enabled +} + +func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) { + if code == "" { + return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil) + } + + logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)]) + + settings := system_setting.GetDiscordSettings() + redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress) + values := url.Values{} + values.Set("client_id", settings.ClientId) + values.Set("client_secret", settings.ClientSecret) + values.Set("code", code) + values.Set("grant_type", "authorization_code") + values.Set("redirect_uri", redirectUri) + + logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri) + + req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode) + + var discordResponse discordOAuthResponse + err = json.NewDecoder(res.Body).Decode(&discordResponse) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error())) + return nil, err + } + + if discordResponse.AccessToken == "" { + logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token") + return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"}) + } + + logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope) + + return &OAuthToken{ + AccessToken: discordResponse.AccessToken, + TokenType: discordResponse.TokenType, + RefreshToken: discordResponse.RefreshToken, + ExpiresIn: discordResponse.ExpiresIn, + Scope: discordResponse.Scope, + IDToken: discordResponse.IDToken, + }, nil +} + +func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) { + logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info") + + req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode) + + if res.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode)) + return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil) + } + + var discordUser discordUser + err = json.NewDecoder(res.Body).Decode(&discordUser) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error())) + return nil, err + } + + if discordUser.UID == "" || discordUser.ID == "" { + logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields") + return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"}) + } + + logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name) + + return &OAuthUser{ + ProviderUserID: discordUser.UID, + Username: discordUser.ID, + DisplayName: discordUser.Name, + }, nil +} + +func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool { + return model.IsDiscordIdAlreadyTaken(providerUserID) +} + +func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error { + user.DiscordId = providerUserID + return user.FillUserByDiscordId() +} + +func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) { + user.DiscordId = providerUserID +} + +func (p *DiscordProvider) GetProviderPrefix() string { + return "discord_" +} diff --git a/oauth/github.go b/oauth/github.go new file mode 100644 index 000000000..d080ce54e --- /dev/null +++ b/oauth/github.go @@ -0,0 +1,160 @@ +package oauth + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/i18n" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" +) + +func init() { + Register("github", &GitHubProvider{}) +} + +// GitHubProvider implements OAuth for GitHub +type GitHubProvider struct{} + +type gitHubOAuthResponse struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +type gitHubUser struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} + +func (p *GitHubProvider) GetName() string { + return "GitHub" +} + +func (p *GitHubProvider) IsEnabled() bool { + return common.GitHubOAuthEnabled +} + +func (p *GitHubProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) { + if code == "" { + return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil) + } + + logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken: code=%s...", code[:min(len(code), 10)]) + + values := map[string]string{ + "client_id": common.GitHubClientId, + "client_secret": common.GitHubClientSecret, + "code": code, + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + client := http.Client{ + Timeout: 20 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken response status: %d", res.StatusCode) + + var oAuthResponse gitHubOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken decode error: %s", err.Error())) + return nil, err + } + + if oAuthResponse.AccessToken == "" { + logger.LogError(ctx, "[OAuth-GitHub] ExchangeToken failed: empty access token") + return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "GitHub"}) + } + + logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken success: scope=%s", oAuthResponse.Scope) + + return &OAuthToken{ + AccessToken: oAuthResponse.AccessToken, + TokenType: oAuthResponse.TokenType, + Scope: oAuthResponse.Scope, + }, nil +} + +func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) { + logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo: fetching user info") + + req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) + + client := http.Client{ + Timeout: 20 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode) + + var githubUser gitHubUser + err = json.NewDecoder(res.Body).Decode(&githubUser) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo decode error: %s", err.Error())) + return nil, err + } + + if githubUser.Login == "" { + logger.LogError(ctx, "[OAuth-GitHub] GetUserInfo failed: empty login field") + return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "GitHub"}) + } + + logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo success: login=%s, name=%s, email=%s", githubUser.Login, githubUser.Name, githubUser.Email) + + return &OAuthUser{ + ProviderUserID: githubUser.Login, + Username: githubUser.Login, + DisplayName: githubUser.Name, + Email: githubUser.Email, + }, nil +} + +func (p *GitHubProvider) IsUserIDTaken(providerUserID string) bool { + return model.IsGitHubIdAlreadyTaken(providerUserID) +} + +func (p *GitHubProvider) FillUserByProviderID(user *model.User, providerUserID string) error { + user.GitHubId = providerUserID + return user.FillUserByGitHubId() +} + +func (p *GitHubProvider) SetProviderUserID(user *model.User, providerUserID string) { + user.GitHubId = providerUserID +} + +func (p *GitHubProvider) GetProviderPrefix() string { + return "github_" +} diff --git a/oauth/linuxdo.go b/oauth/linuxdo.go new file mode 100644 index 000000000..1ed91e009 --- /dev/null +++ b/oauth/linuxdo.go @@ -0,0 +1,195 @@ +package oauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/i18n" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" +) + +func init() { + Register("linuxdo", &LinuxDOProvider{}) +} + +// LinuxDOProvider implements OAuth for Linux DO +type LinuxDOProvider struct{} + +type linuxdoUser struct { + Id int `json:"id"` + Username string `json:"username"` + Name string `json:"name"` + Active bool `json:"active"` + TrustLevel int `json:"trust_level"` + Silenced bool `json:"silenced"` +} + +func (p *LinuxDOProvider) GetName() string { + return "Linux DO" +} + +func (p *LinuxDOProvider) IsEnabled() bool { + return common.LinuxDOOAuthEnabled +} + +func (p *LinuxDOProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) { + if code == "" { + return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil) + } + + logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: code=%s...", code[:min(len(code), 10)]) + + // Get access token using Basic auth + tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token") + credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret + basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials)) + + // Get redirect URI from request + scheme := "http" + if c.Request.TLS != nil { + scheme = "https" + } + redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host) + + logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: token_endpoint=%s, redirect_uri=%s", tokenEndpoint, redirectURI) + + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", code) + data.Set("redirect_uri", redirectURI) + + req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", basicAuth) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + client := http.Client{Timeout: 5 * time.Second} + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken response status: %d", res.StatusCode) + + var tokenRes struct { + AccessToken string `json:"access_token"` + Message string `json:"message"` + } + if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken decode error: %s", err.Error())) + return nil, err + } + + if tokenRes.AccessToken == "" { + logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken failed: %s", tokenRes.Message)) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Linux DO"}, tokenRes.Message) + } + + logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken success") + + return &OAuthToken{ + AccessToken: tokenRes.AccessToken, + }, nil +} + +func (p *LinuxDOProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) { + userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user") + + logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: user_endpoint=%s", userEndpoint) + + req, err := http.NewRequestWithContext(ctx, "GET", userEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + req.Header.Set("Accept", "application/json") + + client := http.Client{Timeout: 5 * time.Second} + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo response status: %d", res.StatusCode) + + var linuxdoUser linuxdoUser + if err := json.NewDecoder(res.Body).Decode(&linuxdoUser); err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo decode error: %s", err.Error())) + return nil, err + } + + if linuxdoUser.Id == 0 { + logger.LogError(ctx, "[OAuth-LinuxDO] GetUserInfo failed: invalid user id") + return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Linux DO"}) + } + + logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: id=%d, username=%s, name=%s, trust_level=%d, active=%v, silenced=%v", + linuxdoUser.Id, linuxdoUser.Username, linuxdoUser.Name, linuxdoUser.TrustLevel, linuxdoUser.Active, linuxdoUser.Silenced) + + // Check trust level + if linuxdoUser.TrustLevel < common.LinuxDOMinimumTrustLevel { + logger.LogWarn(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo: trust level too low (required=%d, current=%d)", + common.LinuxDOMinimumTrustLevel, linuxdoUser.TrustLevel)) + return nil, &TrustLevelError{ + Required: common.LinuxDOMinimumTrustLevel, + Current: linuxdoUser.TrustLevel, + } + } + + logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo success: id=%d, username=%s", linuxdoUser.Id, linuxdoUser.Username) + + return &OAuthUser{ + ProviderUserID: strconv.Itoa(linuxdoUser.Id), + Username: linuxdoUser.Username, + DisplayName: linuxdoUser.Name, + Extra: map[string]any{ + "trust_level": linuxdoUser.TrustLevel, + "active": linuxdoUser.Active, + "silenced": linuxdoUser.Silenced, + }, + }, nil +} + +func (p *LinuxDOProvider) IsUserIDTaken(providerUserID string) bool { + return model.IsLinuxDOIdAlreadyTaken(providerUserID) +} + +func (p *LinuxDOProvider) FillUserByProviderID(user *model.User, providerUserID string) error { + user.LinuxDOId = providerUserID + return user.FillUserByLinuxDOId() +} + +func (p *LinuxDOProvider) SetProviderUserID(user *model.User, providerUserID string) { + user.LinuxDOId = providerUserID +} + +func (p *LinuxDOProvider) GetProviderPrefix() string { + return "linuxdo_" +} + +// TrustLevelError indicates the user's trust level is too low +type TrustLevelError struct { + Required int + Current int +} + +func (e *TrustLevelError) Error() string { + return "trust level too low" +} diff --git a/oauth/oidc.go b/oauth/oidc.go new file mode 100644 index 000000000..9bdc6d01e --- /dev/null +++ b/oauth/oidc.go @@ -0,0 +1,177 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/QuantumNous/new-api/i18n" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" +) + +func init() { + Register("oidc", &OIDCProvider{}) +} + +// OIDCProvider implements OAuth for OIDC +type OIDCProvider struct{} + +type oidcOAuthResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type oidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func (p *OIDCProvider) GetName() string { + return "OIDC" +} + +func (p *OIDCProvider) IsEnabled() bool { + return system_setting.GetOIDCSettings().Enabled +} + +func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) { + if code == "" { + return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil) + } + + logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)]) + + settings := system_setting.GetOIDCSettings() + redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress) + values := url.Values{} + values.Set("client_id", settings.ClientId) + values.Set("client_secret", settings.ClientSecret) + values.Set("code", code) + values.Set("grant_type", "authorization_code") + values.Set("redirect_uri", redirectUri) + + logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri) + + req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode) + + var oidcResponse oidcOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error())) + return nil, err + } + + if oidcResponse.AccessToken == "" { + logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token") + return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"}) + } + + logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope) + + return &OAuthToken{ + AccessToken: oidcResponse.AccessToken, + TokenType: oidcResponse.TokenType, + RefreshToken: oidcResponse.RefreshToken, + ExpiresIn: oidcResponse.ExpiresIn, + Scope: oidcResponse.Scope, + IDToken: oidcResponse.IDToken, + }, nil +} + +func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) { + settings := system_setting.GetOIDCSettings() + + logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint) + + req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error()) + } + defer res.Body.Close() + + logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode) + + if res.StatusCode != http.StatusOK { + logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode)) + return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil) + } + + var oidcUser oidcUser + err = json.NewDecoder(res.Body).Decode(&oidcUser) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error())) + return nil, err + } + + if oidcUser.OpenID == "" || oidcUser.Email == "" { + logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email)) + return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"}) + } + + logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email) + + return &OAuthUser{ + ProviderUserID: oidcUser.OpenID, + Username: oidcUser.PreferredUsername, + DisplayName: oidcUser.Name, + Email: oidcUser.Email, + }, nil +} + +func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool { + return model.IsOidcIdAlreadyTaken(providerUserID) +} + +func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error { + user.OidcId = providerUserID + return user.FillUserByOidcId() +} + +func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) { + user.OidcId = providerUserID +} + +func (p *OIDCProvider) GetProviderPrefix() string { + return "oidc_" +} diff --git a/oauth/provider.go b/oauth/provider.go new file mode 100644 index 000000000..785ed25d2 --- /dev/null +++ b/oauth/provider.go @@ -0,0 +1,36 @@ +package oauth + +import ( + "context" + + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" +) + +// Provider defines the interface for OAuth providers +type Provider interface { + // GetName returns the display name of the provider (e.g., "GitHub", "Discord") + GetName() string + + // IsEnabled returns whether this OAuth provider is enabled + IsEnabled() bool + + // ExchangeToken exchanges the authorization code for an access token + // The gin.Context is passed for providers that need request info (e.g., for redirect_uri) + ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) + + // GetUserInfo retrieves user information using the access token + GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) + + // IsUserIDTaken checks if the provider user ID is already associated with an account + IsUserIDTaken(providerUserID string) bool + + // FillUserByProviderID fills the user model by provider user ID + FillUserByProviderID(user *model.User, providerUserID string) error + + // SetProviderUserID sets the provider user ID on the user model + SetProviderUserID(user *model.User, providerUserID string) + + // GetProviderPrefix returns the prefix for auto-generated usernames (e.g., "github_") + GetProviderPrefix() string +} diff --git a/oauth/registry.go b/oauth/registry.go new file mode 100644 index 000000000..13ee2bcfb --- /dev/null +++ b/oauth/registry.go @@ -0,0 +1,43 @@ +package oauth + +import ( + "sync" +) + +var ( + providers = make(map[string]Provider) + mu sync.RWMutex +) + +// Register registers an OAuth provider with the given name +func Register(name string, provider Provider) { + mu.Lock() + defer mu.Unlock() + providers[name] = provider +} + +// GetProvider returns the OAuth provider for the given name +func GetProvider(name string) Provider { + mu.RLock() + defer mu.RUnlock() + return providers[name] +} + +// GetAllProviders returns all registered OAuth providers +func GetAllProviders() map[string]Provider { + mu.RLock() + defer mu.RUnlock() + result := make(map[string]Provider, len(providers)) + for k, v := range providers { + result[k] = v + } + return result +} + +// IsProviderRegistered checks if a provider is registered +func IsProviderRegistered(name string) bool { + mu.RLock() + defer mu.RUnlock() + _, ok := providers[name] + return ok +} diff --git a/oauth/types.go b/oauth/types.go new file mode 100644 index 000000000..1b0e3646a --- /dev/null +++ b/oauth/types.go @@ -0,0 +1,59 @@ +package oauth + +// OAuthToken represents the token received from OAuth provider +type OAuthToken struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` + IDToken string `json:"id_token,omitempty"` +} + +// OAuthUser represents the user info from OAuth provider +type OAuthUser struct { + // ProviderUserID is the unique identifier from the OAuth provider + ProviderUserID string + // Username is the username from the OAuth provider (e.g., GitHub login) + Username string + // DisplayName is the display name from the OAuth provider + DisplayName string + // Email is the email from the OAuth provider + Email string + // Extra contains any additional provider-specific data + Extra map[string]any +} + +// OAuthError represents a translatable OAuth error +type OAuthError struct { + // MsgKey is the i18n message key + MsgKey string + // Params contains optional parameters for the message template + Params map[string]any + // RawError is the underlying error for logging purposes + RawError string +} + +func (e *OAuthError) Error() string { + if e.RawError != "" { + return e.RawError + } + return e.MsgKey +} + +// NewOAuthError creates a new OAuth error with the given message key +func NewOAuthError(msgKey string, params map[string]any) *OAuthError { + return &OAuthError{ + MsgKey: msgKey, + Params: params, + } +} + +// NewOAuthErrorWithRaw creates a new OAuth error with raw error message for logging +func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) *OAuthError { + return &OAuthError{ + MsgKey: msgKey, + Params: params, + RawError: rawError, + } +} diff --git a/router/api-router.go b/router/api-router.go index e46361c17..2b84295a1 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -4,6 +4,9 @@ import ( "github.com/QuantumNous/new-api/controller" "github.com/QuantumNous/new-api/middleware" + // Import oauth package to register providers via init() + _ "github.com/QuantumNous/new-api/oauth" + "github.com/gin-contrib/gzip" "github.com/gin-gonic/gin" ) @@ -30,16 +33,16 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) - apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) - apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth) - apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth) - apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth) + // OAuth routes - specific routes must come before :provider wildcard apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) + apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) + // Non-standard OAuth (WeChat, Telegram) - keep original routes apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind) - apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) + // Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route + apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth) apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig) apiRouter.POST("/stripe/webhook", controller.StripeWebhook) diff --git a/web/src/components/auth/OAuth2Callback.jsx b/web/src/components/auth/OAuth2Callback.jsx index e43e9e033..c0c6418a1 100644 --- a/web/src/components/auth/OAuth2Callback.jsx +++ b/web/src/components/auth/OAuth2Callback.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React, { useContext, useEffect } from 'react'; +import React, { useContext, useEffect, useRef } from 'react'; import { useNavigate, useSearchParams } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import { @@ -35,6 +35,9 @@ const OAuth2Callback = (props) => { const [searchParams] = useSearchParams(); const [, userDispatch] = useContext(UserContext); const navigate = useNavigate(); + + // 防止 React 18 Strict Mode 下重复执行 + const hasExecuted = useRef(false); // 最大重试次数 const MAX_RETRIES = 3; @@ -48,7 +51,9 @@ const OAuth2Callback = (props) => { const { success, message, data } = resData; if (!success) { - throw new Error(message || 'OAuth2 callback error'); + // 业务错误不重试,直接显示错误 + showError(message || t('授权失败')); + return; } if (message === 'bind') { @@ -63,6 +68,7 @@ const OAuth2Callback = (props) => { navigate('/console/token'); } } catch (error) { + // 网络错误等可重试 if (retry < MAX_RETRIES) { // 递增的退避等待 await new Promise((resolve) => setTimeout(resolve, (retry + 1) * 2000)); @@ -76,6 +82,12 @@ const OAuth2Callback = (props) => { }; useEffect(() => { + // 防止 React 18 Strict Mode 下重复执行 + if (hasExecuted.current) { + return; + } + hasExecuted.current = true; + const code = searchParams.get('code'); const state = searchParams.get('state');