mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
refactor: unify OAuth providers with i18n support
- Introduce Provider interface pattern for standard OAuth protocols - Create unified controller/oauth.go with common OAuth logic - Add OAuthError type for translatable error messages - Add i18n keys and translations (zh/en) for OAuth messages - Use common.ApiErrorI18n/ApiSuccessI18n for consistent responses - Preserve backward compatibility for existing routes and data
This commit is contained in:
@@ -1,223 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type GitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse GitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var githubUser GitHubUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if githubUser.Login == "" {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
return &githubUser, nil
|
||||
}
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
// IsGitHubIdAlreadyTaken is unscoped
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
// FillUserByGitHubId is scoped
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
} else {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 GitHub 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := common.GetRandomString(12)
|
||||
affCode := c.Query("aff")
|
||||
if affCode != "" {
|
||||
session.Set("aff", affCode)
|
||||
}
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LinuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Linux DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to connect to Linux DO server")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
req, err = http.NewRequest("GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get user info from Linux DO")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
|
||||
var linuxdoUser LinuxdoUser
|
||||
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
return nil, errors.New("invalid user info returned")
|
||||
}
|
||||
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxdoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
err := user.FillUserByLinuxDOId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
257
controller/oauth.go
Normal file
257
controller/oauth.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// providerParams returns map with Provider key for i18n templates
|
||||
func providerParams(name string) map[string]any {
|
||||
return map[string]any{"Provider": name}
|
||||
}
|
||||
|
||||
// GenerateOAuthCode generates a state code for OAuth CSRF protection
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := common.GetRandomString(12)
|
||||
affCode := c.Query("aff")
|
||||
if affCode != "" {
|
||||
session.Set("aff", affCode)
|
||||
}
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleOAuth handles OAuth callback for all standard OAuth providers
|
||||
func HandleOAuth(c *gin.Context) {
|
||||
providerName := c.Param("provider")
|
||||
provider := oauth.GetProvider(providerName)
|
||||
if provider == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
|
||||
// 1. Validate state (CSRF protection)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.T(c, i18n.MsgOAuthStateInvalid),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Check if user is already logged in (bind flow)
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
handleOAuthBind(c, provider)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Check if provider is enabled
|
||||
if !provider.IsEnabled() {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Handle error from provider
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Exchange code for token
|
||||
code := c.Query("code")
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Get user info
|
||||
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 7. Find or create user
|
||||
user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case *OAuthUserDeletedError:
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
|
||||
case *OAuthRegistrationDisabledError:
|
||||
common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
|
||||
default:
|
||||
common.ApiError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 8. Check user status
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
|
||||
return
|
||||
}
|
||||
|
||||
// 9. Setup login
|
||||
setupLogin(user, c)
|
||||
}
|
||||
|
||||
// handleOAuthBind handles binding OAuth account to existing user
|
||||
func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
if !provider.IsEnabled() {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
code := c.Query("code")
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user info
|
||||
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this OAuth account is already bound
|
||||
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
// Get current user from session
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user := model.User{Id: id.(int)}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update user with OAuth ID
|
||||
provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
||||
}
|
||||
|
||||
// findOrCreateOAuthUser finds existing user or creates new user
|
||||
func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
|
||||
user := &model.User{}
|
||||
|
||||
// Check if user already exists
|
||||
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Check if user has been deleted
|
||||
if user.Id == 0 {
|
||||
return nil, &OAuthUserDeletedError{}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// User doesn't exist, create new user if registration is enabled
|
||||
if !common.RegisterEnabled {
|
||||
return nil, &OAuthRegistrationDisabledError{}
|
||||
}
|
||||
|
||||
// Set up new user
|
||||
user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if oauthUser.DisplayName != "" {
|
||||
user.DisplayName = oauthUser.DisplayName
|
||||
} else if oauthUser.Username != "" {
|
||||
user.DisplayName = oauthUser.Username
|
||||
} else {
|
||||
user.DisplayName = provider.GetName() + " User"
|
||||
}
|
||||
if oauthUser.Email != "" {
|
||||
user.Email = oauthUser.Email
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
|
||||
// Handle affiliate code
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Error types for OAuth
|
||||
type OAuthUserDeletedError struct{}
|
||||
|
||||
func (e *OAuthUserDeletedError) Error() string {
|
||||
return "user has been deleted"
|
||||
}
|
||||
|
||||
type OAuthRegistrationDisabledError struct{}
|
||||
|
||||
func (e *OAuthRegistrationDisabledError) Error() string {
|
||||
return "registration is disabled"
|
||||
}
|
||||
|
||||
// handleOAuthError handles OAuth errors and returns translated message
|
||||
func handleOAuthError(c *gin.Context, err error) {
|
||||
switch e := err.(type) {
|
||||
case *oauth.OAuthError:
|
||||
if e.Params != nil {
|
||||
common.ApiErrorI18n(c, e.MsgKey, e.Params)
|
||||
} else {
|
||||
common.ApiErrorI18n(c, e.MsgKey)
|
||||
}
|
||||
case *oauth.TrustLevelError:
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
|
||||
default:
|
||||
common.ApiError(c, err)
|
||||
}
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
17
i18n/keys.go
17
i18n/keys.go
@@ -264,9 +264,20 @@ const (
|
||||
|
||||
// OAuth related messages
|
||||
const (
|
||||
MsgOAuthInvalidCode = "oauth.invalid_code"
|
||||
MsgOAuthGetUserErr = "oauth.get_user_error"
|
||||
MsgOAuthAccountUsed = "oauth.account_used"
|
||||
MsgOAuthInvalidCode = "oauth.invalid_code"
|
||||
MsgOAuthGetUserErr = "oauth.get_user_error"
|
||||
MsgOAuthAccountUsed = "oauth.account_used"
|
||||
MsgOAuthUnknownProvider = "oauth.unknown_provider"
|
||||
MsgOAuthStateInvalid = "oauth.state_invalid"
|
||||
MsgOAuthNotEnabled = "oauth.not_enabled"
|
||||
MsgOAuthUserDeleted = "oauth.user_deleted"
|
||||
MsgOAuthUserBanned = "oauth.user_banned"
|
||||
MsgOAuthBindSuccess = "oauth.bind_success"
|
||||
MsgOAuthAlreadyBound = "oauth.already_bound"
|
||||
MsgOAuthConnectFailed = "oauth.connect_failed"
|
||||
MsgOAuthTokenFailed = "oauth.token_failed"
|
||||
MsgOAuthUserInfoEmpty = "oauth.user_info_empty"
|
||||
MsgOAuthTrustLevelLow = "oauth.trust_level_low"
|
||||
)
|
||||
|
||||
// Model layer error messages (for translation in controller)
|
||||
|
||||
@@ -223,6 +223,17 @@ ability.repair_running: "A repair task is already running, please try again late
|
||||
oauth.invalid_code: "Invalid authorization code"
|
||||
oauth.get_user_error: "Failed to get user information"
|
||||
oauth.account_used: "This account has been bound to another user"
|
||||
oauth.unknown_provider: "Unknown OAuth provider"
|
||||
oauth.state_invalid: "State parameter is empty or mismatched"
|
||||
oauth.not_enabled: "{{.Provider}} login and registration has not been enabled by administrator"
|
||||
oauth.user_deleted: "User has been deleted"
|
||||
oauth.user_banned: "User has been banned"
|
||||
oauth.bind_success: "Binding successful"
|
||||
oauth.already_bound: "This {{.Provider}} account has already been bound"
|
||||
oauth.connect_failed: "Unable to connect to {{.Provider}} server, please try again later"
|
||||
oauth.token_failed: "Failed to get token from {{.Provider}}, please check settings"
|
||||
oauth.user_info_empty: "{{.Provider}} returned empty user info, please check settings"
|
||||
oauth.trust_level_low: "Linux DO trust level does not meet the minimum required by administrator"
|
||||
|
||||
# Model layer error messages
|
||||
redeem.failed: "Redemption failed, please try again later"
|
||||
|
||||
@@ -224,6 +224,17 @@ ability.repair_running: "已经有一个修复任务在运行中,请稍后再
|
||||
oauth.invalid_code: "无效的授权码"
|
||||
oauth.get_user_error: "获取用户信息失败"
|
||||
oauth.account_used: "该账户已被其他用户绑定"
|
||||
oauth.unknown_provider: "未知的 OAuth 提供商"
|
||||
oauth.state_invalid: "state 参数为空或不匹配"
|
||||
oauth.not_enabled: "管理员未开启通过 {{.Provider}} 登录以及注册"
|
||||
oauth.user_deleted: "用户已注销"
|
||||
oauth.user_banned: "用户已被封禁"
|
||||
oauth.bind_success: "绑定成功"
|
||||
oauth.already_bound: "该 {{.Provider}} 账户已被绑定"
|
||||
oauth.connect_failed: "无法连接至 {{.Provider}} 服务器,请稍后重试"
|
||||
oauth.token_failed: "{{.Provider}} 获取 Token 失败,请检查设置"
|
||||
oauth.user_info_empty: "{{.Provider}} 获取用户信息为空,请检查设置"
|
||||
oauth.trust_level_low: "Linux DO 信任等级未达到管理员设置的最低信任等级"
|
||||
|
||||
# Model layer error messages
|
||||
redeem.failed: "兑换失败,请稍后重试"
|
||||
|
||||
172
oauth/discord.go
Normal file
172
oauth/discord.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("discord", &DiscordProvider{})
|
||||
}
|
||||
|
||||
// DiscordProvider implements OAuth for Discord
|
||||
type DiscordProvider struct{}
|
||||
|
||||
type discordOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type discordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) GetName() string {
|
||||
return "Discord"
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) IsEnabled() bool {
|
||||
return system_setting.GetDiscordSettings().Enabled
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
settings := system_setting.GetDiscordSettings()
|
||||
redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)
|
||||
values := url.Values{}
|
||||
values.Set("client_id", settings.ClientId)
|
||||
values.Set("client_secret", settings.ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", redirectUri)
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var discordResponse discordOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: discordResponse.AccessToken,
|
||||
TokenType: discordResponse.TokenType,
|
||||
RefreshToken: discordResponse.RefreshToken,
|
||||
ExpiresIn: discordResponse.ExpiresIn,
|
||||
Scope: discordResponse.Scope,
|
||||
IDToken: discordResponse.IDToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
||||
}
|
||||
|
||||
var discordUser discordUser
|
||||
err = json.NewDecoder(res.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: discordUser.UID,
|
||||
Username: discordUser.ID,
|
||||
DisplayName: discordUser.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsDiscordIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.DiscordId = providerUserID
|
||||
return user.FillUserByDiscordId()
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.DiscordId = providerUserID
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) GetProviderPrefix() string {
|
||||
return "discord_"
|
||||
}
|
||||
160
oauth/github.go
Normal file
160
oauth/github.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("github", &GitHubProvider{})
|
||||
}
|
||||
|
||||
// GitHubProvider implements OAuth for GitHub
|
||||
type GitHubProvider struct{}
|
||||
|
||||
type gitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type gitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetName() string {
|
||||
return "GitHub"
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) IsEnabled() bool {
|
||||
return common.GitHubOAuthEnabled
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
values := map[string]string{
|
||||
"client_id": common.GitHubClientId,
|
||||
"client_secret": common.GitHubClientSecret,
|
||||
"code": code,
|
||||
}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var oAuthResponse gitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oAuthResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, "[OAuth-GitHub] ExchangeToken failed: empty access token")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "GitHub"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken success: scope=%s", oAuthResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: oAuthResponse.AccessToken,
|
||||
TokenType: oAuthResponse.TokenType,
|
||||
Scope: oAuthResponse.Scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo: fetching user info")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
var githubUser gitHubUser
|
||||
err = json.NewDecoder(res.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if githubUser.Login == "" {
|
||||
logger.LogError(ctx, "[OAuth-GitHub] GetUserInfo failed: empty login field")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "GitHub"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo success: login=%s, name=%s, email=%s", githubUser.Login, githubUser.Name, githubUser.Email)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: githubUser.Login,
|
||||
Username: githubUser.Login,
|
||||
DisplayName: githubUser.Name,
|
||||
Email: githubUser.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsGitHubIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.GitHubId = providerUserID
|
||||
return user.FillUserByGitHubId()
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.GitHubId = providerUserID
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetProviderPrefix() string {
|
||||
return "github_"
|
||||
}
|
||||
195
oauth/linuxdo.go
Normal file
195
oauth/linuxdo.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("linuxdo", &LinuxDOProvider{})
|
||||
}
|
||||
|
||||
// LinuxDOProvider implements OAuth for Linux DO
|
||||
type LinuxDOProvider struct{}
|
||||
|
||||
type linuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) GetName() string {
|
||||
return "Linux DO"
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) IsEnabled() bool {
|
||||
return common.LinuxDOOAuthEnabled
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: token_endpoint=%s, redirect_uri=%s", tokenEndpoint, redirectURI)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken failed: %s", tokenRes.Message))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Linux DO"}, tokenRes.Message)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken success")
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: user_endpoint=%s", userEndpoint)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
var linuxdoUser linuxdoUser
|
||||
if err := json.NewDecoder(res.Body).Decode(&linuxdoUser); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
logger.LogError(ctx, "[OAuth-LinuxDO] GetUserInfo failed: invalid user id")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Linux DO"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: id=%d, username=%s, name=%s, trust_level=%d, active=%v, silenced=%v",
|
||||
linuxdoUser.Id, linuxdoUser.Username, linuxdoUser.Name, linuxdoUser.TrustLevel, linuxdoUser.Active, linuxdoUser.Silenced)
|
||||
|
||||
// Check trust level
|
||||
if linuxdoUser.TrustLevel < common.LinuxDOMinimumTrustLevel {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo: trust level too low (required=%d, current=%d)",
|
||||
common.LinuxDOMinimumTrustLevel, linuxdoUser.TrustLevel))
|
||||
return nil, &TrustLevelError{
|
||||
Required: common.LinuxDOMinimumTrustLevel,
|
||||
Current: linuxdoUser.TrustLevel,
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo success: id=%d, username=%s", linuxdoUser.Id, linuxdoUser.Username)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: strconv.Itoa(linuxdoUser.Id),
|
||||
Username: linuxdoUser.Username,
|
||||
DisplayName: linuxdoUser.Name,
|
||||
Extra: map[string]any{
|
||||
"trust_level": linuxdoUser.TrustLevel,
|
||||
"active": linuxdoUser.Active,
|
||||
"silenced": linuxdoUser.Silenced,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsLinuxDOIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.LinuxDOId = providerUserID
|
||||
return user.FillUserByLinuxDOId()
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.LinuxDOId = providerUserID
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) GetProviderPrefix() string {
|
||||
return "linuxdo_"
|
||||
}
|
||||
|
||||
// TrustLevelError indicates the user's trust level is too low
|
||||
type TrustLevelError struct {
|
||||
Required int
|
||||
Current int
|
||||
}
|
||||
|
||||
func (e *TrustLevelError) Error() string {
|
||||
return "trust level too low"
|
||||
}
|
||||
177
oauth/oidc.go
Normal file
177
oauth/oidc.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("oidc", &OIDCProvider{})
|
||||
}
|
||||
|
||||
// OIDCProvider implements OAuth for OIDC
|
||||
type OIDCProvider struct{}
|
||||
|
||||
type oidcOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type oidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetName() string {
|
||||
return "OIDC"
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) IsEnabled() bool {
|
||||
return system_setting.GetOIDCSettings().Enabled
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
settings := system_setting.GetOIDCSettings()
|
||||
redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)
|
||||
values := url.Values{}
|
||||
values.Set("client_id", settings.ClientId)
|
||||
values.Set("client_secret", settings.ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", redirectUri)
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var oidcResponse oidcOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: oidcResponse.AccessToken,
|
||||
TokenType: oidcResponse.TokenType,
|
||||
RefreshToken: oidcResponse.RefreshToken,
|
||||
ExpiresIn: oidcResponse.ExpiresIn,
|
||||
Scope: oidcResponse.Scope,
|
||||
IDToken: oidcResponse.IDToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
settings := system_setting.GetOIDCSettings()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
||||
}
|
||||
|
||||
var oidcUser oidcUser
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: oidcUser.OpenID,
|
||||
Username: oidcUser.PreferredUsername,
|
||||
DisplayName: oidcUser.Name,
|
||||
Email: oidcUser.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsOidcIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.OidcId = providerUserID
|
||||
return user.FillUserByOidcId()
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.OidcId = providerUserID
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetProviderPrefix() string {
|
||||
return "oidc_"
|
||||
}
|
||||
36
oauth/provider.go
Normal file
36
oauth/provider.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Provider defines the interface for OAuth providers
|
||||
type Provider interface {
|
||||
// GetName returns the display name of the provider (e.g., "GitHub", "Discord")
|
||||
GetName() string
|
||||
|
||||
// IsEnabled returns whether this OAuth provider is enabled
|
||||
IsEnabled() bool
|
||||
|
||||
// ExchangeToken exchanges the authorization code for an access token
|
||||
// The gin.Context is passed for providers that need request info (e.g., for redirect_uri)
|
||||
ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error)
|
||||
|
||||
// GetUserInfo retrieves user information using the access token
|
||||
GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error)
|
||||
|
||||
// IsUserIDTaken checks if the provider user ID is already associated with an account
|
||||
IsUserIDTaken(providerUserID string) bool
|
||||
|
||||
// FillUserByProviderID fills the user model by provider user ID
|
||||
FillUserByProviderID(user *model.User, providerUserID string) error
|
||||
|
||||
// SetProviderUserID sets the provider user ID on the user model
|
||||
SetProviderUserID(user *model.User, providerUserID string)
|
||||
|
||||
// GetProviderPrefix returns the prefix for auto-generated usernames (e.g., "github_")
|
||||
GetProviderPrefix() string
|
||||
}
|
||||
43
oauth/registry.go
Normal file
43
oauth/registry.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
providers = make(map[string]Provider)
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// Register registers an OAuth provider with the given name
|
||||
func Register(name string, provider Provider) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
providers[name] = provider
|
||||
}
|
||||
|
||||
// GetProvider returns the OAuth provider for the given name
|
||||
func GetProvider(name string) Provider {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return providers[name]
|
||||
}
|
||||
|
||||
// GetAllProviders returns all registered OAuth providers
|
||||
func GetAllProviders() map[string]Provider {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
result := make(map[string]Provider, len(providers))
|
||||
for k, v := range providers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IsProviderRegistered checks if a provider is registered
|
||||
func IsProviderRegistered(name string) bool {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
_, ok := providers[name]
|
||||
return ok
|
||||
}
|
||||
59
oauth/types.go
Normal file
59
oauth/types.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package oauth
|
||||
|
||||
// OAuthToken represents the token received from OAuth provider
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthUser represents the user info from OAuth provider
|
||||
type OAuthUser struct {
|
||||
// ProviderUserID is the unique identifier from the OAuth provider
|
||||
ProviderUserID string
|
||||
// Username is the username from the OAuth provider (e.g., GitHub login)
|
||||
Username string
|
||||
// DisplayName is the display name from the OAuth provider
|
||||
DisplayName string
|
||||
// Email is the email from the OAuth provider
|
||||
Email string
|
||||
// Extra contains any additional provider-specific data
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// OAuthError represents a translatable OAuth error
|
||||
type OAuthError struct {
|
||||
// MsgKey is the i18n message key
|
||||
MsgKey string
|
||||
// Params contains optional parameters for the message template
|
||||
Params map[string]any
|
||||
// RawError is the underlying error for logging purposes
|
||||
RawError string
|
||||
}
|
||||
|
||||
func (e *OAuthError) Error() string {
|
||||
if e.RawError != "" {
|
||||
return e.RawError
|
||||
}
|
||||
return e.MsgKey
|
||||
}
|
||||
|
||||
// NewOAuthError creates a new OAuth error with the given message key
|
||||
func NewOAuthError(msgKey string, params map[string]any) *OAuthError {
|
||||
return &OAuthError{
|
||||
MsgKey: msgKey,
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
|
||||
// NewOAuthErrorWithRaw creates a new OAuth error with raw error message for logging
|
||||
func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) *OAuthError {
|
||||
return &OAuthError{
|
||||
MsgKey: msgKey,
|
||||
Params: params,
|
||||
RawError: rawError,
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"github.com/QuantumNous/new-api/controller"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
|
||||
// Import oauth package to register providers via init()
|
||||
_ "github.com/QuantumNous/new-api/oauth"
|
||||
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -30,16 +33,16 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
// OAuth routes - specific routes must come before :provider wildcard
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
// Non-standard OAuth (WeChat, Telegram) - keep original routes
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
|
||||
apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth)
|
||||
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useContext, useEffect } from 'react';
|
||||
import React, { useContext, useEffect, useRef } from 'react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -35,6 +35,9 @@ const OAuth2Callback = (props) => {
|
||||
const [searchParams] = useSearchParams();
|
||||
const [, userDispatch] = useContext(UserContext);
|
||||
const navigate = useNavigate();
|
||||
|
||||
// 防止 React 18 Strict Mode 下重复执行
|
||||
const hasExecuted = useRef(false);
|
||||
|
||||
// 最大重试次数
|
||||
const MAX_RETRIES = 3;
|
||||
@@ -48,7 +51,9 @@ const OAuth2Callback = (props) => {
|
||||
const { success, message, data } = resData;
|
||||
|
||||
if (!success) {
|
||||
throw new Error(message || 'OAuth2 callback error');
|
||||
// 业务错误不重试,直接显示错误
|
||||
showError(message || t('授权失败'));
|
||||
return;
|
||||
}
|
||||
|
||||
if (message === 'bind') {
|
||||
@@ -63,6 +68,7 @@ const OAuth2Callback = (props) => {
|
||||
navigate('/console/token');
|
||||
}
|
||||
} catch (error) {
|
||||
// 网络错误等可重试
|
||||
if (retry < MAX_RETRIES) {
|
||||
// 递增的退避等待
|
||||
await new Promise((resolve) => setTimeout(resolve, (retry + 1) * 2000));
|
||||
@@ -76,6 +82,12 @@ const OAuth2Callback = (props) => {
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
// 防止 React 18 Strict Mode 下重复执行
|
||||
if (hasExecuted.current) {
|
||||
return;
|
||||
}
|
||||
hasExecuted.current = true;
|
||||
|
||||
const code = searchParams.get('code');
|
||||
const state = searchParams.get('state');
|
||||
|
||||
|
||||
Reference in New Issue
Block a user