diff --git a/README.en.md b/README.en.md index e51d4c89a..55e2ffe5b 100644 --- a/README.en.md +++ b/README.en.md @@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \ ### 🔐 Authorization and Security +- 😈 Discord authorization login - 🤖 LinuxDO authorization login - 📱 Telegram authorization login - 🔑 OIDC unified authentication diff --git a/README.md b/README.md index aaa64e31e..e0759506c 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \ ### 🔐 授权与安全 +- 😈 Discord 授权登录 - 🤖 LinuxDO 授权登录 - 📱 Telegram 授权登录 - 🔑 OIDC 统一认证 diff --git a/controller/discord.go b/controller/discord.go new file mode 100644 index 000000000..41dd59808 --- /dev/null +++ b/controller/discord.go @@ -0,0 +1,223 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/system_setting" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +type DiscordResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type DiscordUser struct { + UID string `json:"id"` + ID string `json:"username"` + Name string `json:"global_name"` +} + +func getDiscordUserInfoByCode(code string) (*DiscordUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + + values := url.Values{} + values.Set("client_id", system_setting.GetDiscordSettings().ClientId) + values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret) + values.Set("code", code) + values.Set("grant_type", "authorization_code") + values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)) + formData := values.Encode() + req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + common.SysLog(err.Error()) + return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!") + } + defer res.Body.Close() + var discordResponse DiscordResponse + err = json.NewDecoder(res.Body).Decode(&discordResponse) + if err != nil { + return nil, err + } + + if discordResponse.AccessToken == "" { + common.SysError("Discord 获取 Token 失败,请检查设置!") + return nil, errors.New("Discord 获取 Token 失败,请检查设置!") + } + + req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + common.SysLog(err.Error()) + return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!") + } + defer res2.Body.Close() + if res2.StatusCode != http.StatusOK { + common.SysError("Discord 获取用户信息失败!请检查设置!") + return nil, errors.New("Discord 获取用户信息失败!请检查设置!") + } + + var discordUser DiscordUser + err = json.NewDecoder(res2.Body).Decode(&discordUser) + if err != nil { + return nil, err + } + if discordUser.UID == "" || discordUser.ID == "" { + common.SysError("Discord 获取用户信息为空!请检查设置!") + return nil, errors.New("Discord 获取用户信息为空!请检查设置!") + } + return &discordUser, nil +} + +func DiscordOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + DiscordBind(c) + return + } + if !system_setting.GetDiscordSettings().Enabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Discord 登录以及注册", + }) + return + } + code := c.Query("code") + discordUser, err := getDiscordUserInfoByCode(code) + if err != nil { + common.ApiError(c, err) + return + } + user := model.User{ + DiscordId: discordUser.UID, + } + if model.IsDiscordIdAlreadyTaken(user.DiscordId) { + err := user.FillUserByDiscordId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + if discordUser.ID != "" { + user.Username = discordUser.ID + } else { + user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if discordUser.Name != "" { + user.DisplayName = discordUser.Name + } else { + user.DisplayName = "Discord User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func DiscordBind(c *gin.Context) { + if !system_setting.GetDiscordSettings().Enabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Discord 登录以及注册", + }) + return + } + code := c.Query("code") + discordUser, err := getDiscordUserInfoByCode(code) + if err != nil { + common.ApiError(c, err) + return + } + user := model.User{ + DiscordId: discordUser.UID, + } + if model.IsDiscordIdAlreadyTaken(user.DiscordId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 Discord 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + common.ApiError(c, err) + return + } + user.DiscordId = discordUser.UID + err = user.Update(false) + if err != nil { + common.ApiError(c, err) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) +} diff --git a/controller/misc.go b/controller/misc.go index 83b43fb57..70415137a 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -52,6 +52,8 @@ func GetStatus(c *gin.Context) { "email_verification": common.EmailVerificationEnabled, "github_oauth": common.GitHubOAuthEnabled, "github_client_id": common.GitHubClientId, + "discord_oauth": system_setting.GetDiscordSettings().Enabled, + "discord_client_id": system_setting.GetDiscordSettings().ClientId, "linuxdo_oauth": common.LinuxDOOAuthEnabled, "linuxdo_client_id": common.LinuxDOClientId, "linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel, diff --git a/controller/option.go b/controller/option.go index 56f65f5ff..89b2fc4d5 100644 --- a/controller/option.go +++ b/controller/option.go @@ -71,6 +71,14 @@ func UpdateOption(c *gin.Context) { }) return } + case "discord.enabled": + if option.Value == "true" && system_setting.GetDiscordSettings().ClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Discord OAuth,请先填入 Discord Client Id 以及 Discord Client Secret!", + }) + return + } case "oidc.enabled": if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/user.go b/controller/user.go index eda4f7f12..ef4f0ddc0 100644 --- a/controller/user.go +++ b/controller/user.go @@ -453,6 +453,7 @@ func GetSelf(c *gin.Context) { "status": user.Status, "email": user.Email, "github_id": user.GitHubId, + "discord_id": user.DiscordId, "oidc_id": user.OidcId, "wechat_id": user.WeChatId, "telegram_id": user.TelegramId, diff --git a/docs/api/web_api.md b/docs/api/web_api.md index e64fd3594..aa88a606c 100644 --- a/docs/api/web_api.md +++ b/docs/api/web_api.md @@ -42,6 +42,7 @@ | 方法 | 路径 | 鉴权 | 说明 | |------|------|------|------| | GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 | +| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 | | GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 | | GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 | | GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 | diff --git a/model/user.go b/model/user.go index 78365e06e..395daa0b5 100644 --- a/model/user.go +++ b/model/user.go @@ -27,6 +27,7 @@ type User struct { Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` + DiscordId string `json:"discord_id" gorm:"column:discord_id;index"` OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"` @@ -539,6 +540,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByDiscordId() error { + if user.DiscordId == "" { + return errors.New("discord id 为空!") + } + DB.Where(User{DiscordId: user.DiscordId}).First(user) + return nil +} + func (user *User) FillUserByOidcId() error { if user.OidcId == "" { return errors.New("oidc id 为空!") @@ -578,6 +587,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsDiscordIdAlreadyTaken(discordId string) bool { + return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1 +} + func IsOidcIdAlreadyTaken(oidcId string) bool { return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index 9506875cd..fd204e7e6 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -30,6 +30,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) + apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth) apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth) apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) diff --git a/setting/system_setting/discord.go b/setting/system_setting/discord.go new file mode 100644 index 000000000..f4e763ffa --- /dev/null +++ b/setting/system_setting/discord.go @@ -0,0 +1,21 @@ +package system_setting + +import "github.com/QuantumNous/new-api/setting/config" + +type DiscordSettings struct { + Enabled bool `json:"enabled"` + ClientId string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// 默认配置 +var defaultDiscordSettings = DiscordSettings{} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("discord", &defaultDiscordSettings) +} + +func GetDiscordSettings() *DiscordSettings { + return &defaultDiscordSettings +} diff --git a/web/src/App.jsx b/web/src/App.jsx index 06e364897..b0f281c45 100644 --- a/web/src/App.jsx +++ b/web/src/App.jsx @@ -192,6 +192,14 @@ function App() { } /> + } key={location.pathname}> + + + } + /> { let navigate = useNavigate(); @@ -73,6 +75,7 @@ const LoginForm = () => { const [showEmailLogin, setShowEmailLogin] = useState(false); const [wechatLoading, setWechatLoading] = useState(false); const [githubLoading, setGithubLoading] = useState(false); + const [discordLoading, setDiscordLoading] = useState(false); const [oidcLoading, setOidcLoading] = useState(false); const [linuxdoLoading, setLinuxdoLoading] = useState(false); const [emailLoginLoading, setEmailLoginLoading] = useState(false); @@ -298,6 +301,21 @@ const LoginForm = () => { } }; + // 包装的Discord登录点击处理 + const handleDiscordClick = () => { + if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) { + showInfo(t('请先阅读并同意用户协议和隐私政策')); + return; + } + setDiscordLoading(true); + try { + onDiscordOAuthClicked(status.discord_client_id); + } finally { + // 由于重定向,这里不会执行到,但为了完整性添加 + setTimeout(() => setDiscordLoading(false), 3000); + } + }; + // 包装的OIDC登录点击处理 const handleOIDCClick = () => { if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) { @@ -472,6 +490,19 @@ const LoginForm = () => { )} + {status.discord_oauth && ( + + )} + {status.oidc_enabled && ( )} + {status.discord_oauth && ( + + )} + {status.oidc_enabled && ( + + + {t('用以支持通过 Discord 进行登录注册')} + + + + + + + + + + + + diff --git a/web/src/components/settings/personal/cards/AccountManagement.jsx b/web/src/components/settings/personal/cards/AccountManagement.jsx index d54edb93a..9c61bbf0f 100644 --- a/web/src/components/settings/personal/cards/AccountManagement.jsx +++ b/web/src/components/settings/personal/cards/AccountManagement.jsx @@ -38,13 +38,14 @@ import { IconLock, IconDelete, } from '@douyinfe/semi-icons'; -import { SiTelegram, SiWechat, SiLinux } from 'react-icons/si'; +import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si'; import { UserPlus, ShieldCheck } from 'lucide-react'; import TelegramLoginButton from 'react-telegram-login'; import { onGitHubOAuthClicked, onOIDCClicked, onLinuxDOOAuthClicked, + onDiscordOAuthClicked, } from '../../../../helpers'; import TwoFASetting from '../components/TwoFASetting'; @@ -247,6 +248,47 @@ const AccountManagement = ({ + {/* Discord绑定 */} + +
+
+
+ +
+
+
+ {t('Discord')} +
+
+ {renderAccountInfo( + userState.user?.discord_id, + t('Discord ID'), + )} +
+
+
+
+ +
+
+
+ {/* OIDC绑定 */}
diff --git a/web/src/components/table/users/modals/EditUserModal.jsx b/web/src/components/table/users/modals/EditUserModal.jsx index 7f5a8414c..0ef4a9a78 100644 --- a/web/src/components/table/users/modals/EditUserModal.jsx +++ b/web/src/components/table/users/modals/EditUserModal.jsx @@ -72,6 +72,7 @@ const EditUserModal = (props) => { password: '', github_id: '', oidc_id: '', + discord_id: '', wechat_id: '', telegram_id: '', email: '', @@ -332,6 +333,7 @@ const EditUserModal = (props) => { {[ 'github_id', + 'discord_id', 'oidc_id', 'wechat_id', 'email', diff --git a/web/src/helpers/api.js b/web/src/helpers/api.js index 1ccfffaf2..b87e5a2f8 100644 --- a/web/src/helpers/api.js +++ b/web/src/helpers/api.js @@ -231,6 +231,17 @@ export async function getOAuthState() { } } +export async function onDiscordOAuthClicked(client_id) { + const state = await getOAuthState(); + if (!state) return; + const redirect_uri = `${window.location.origin}/oauth/discord`; + const response_type = 'code'; + const scope = 'identify+openid'; + window.open( + `https://discord.com/oauth2/authorize?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`, + ); +} + export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) { const state = await getOAuthState(); if (!state) return; diff --git a/web/src/i18n/locales/zh.json b/web/src/i18n/locales/zh.json index 29c1c7f40..541912b20 100644 --- a/web/src/i18n/locales/zh.json +++ b/web/src/i18n/locales/zh.json @@ -257,6 +257,7 @@ "余额充值管理": "余额充值管理", "你似乎并没有修改什么": "你似乎并没有修改什么", "使用 GitHub 继续": "使用 GitHub 继续", + "使用 Discord 继续": "使用 Discord 继续", "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}", "使用 LinuxDO 继续": "使用 LinuxDO 继续", "使用 OIDC 继续": "使用 OIDC 继续",