refactor: unify OAuth providers with i18n support

- Introduce Provider interface pattern for standard OAuth protocols
- Create unified controller/oauth.go with common OAuth logic
- Add OAuthError type for translatable error messages
- Add i18n keys and translations (zh/en) for OAuth messages
- Use common.ApiErrorI18n/ApiSuccessI18n for consistent responses
- Preserve backward compatibility for existing routes and data
This commit is contained in:
CaIon
2026-02-05 20:21:38 +08:00
parent c540033985
commit df6c669e73
17 changed files with 1157 additions and 969 deletions

172
oauth/discord.go Normal file
View File

@@ -0,0 +1,172 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func init() {
Register("discord", &DiscordProvider{})
}
// DiscordProvider implements OAuth for Discord
type DiscordProvider struct{}
type discordOAuthResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type discordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func (p *DiscordProvider) GetName() string {
return "Discord"
}
func (p *DiscordProvider) IsEnabled() bool {
return system_setting.GetDiscordSettings().Enabled
}
func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)])
settings := system_setting.GetDiscordSettings()
redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)
values := url.Values{}
values.Set("client_id", settings.ClientId)
values.Set("client_secret", settings.ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", redirectUri)
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri)
req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode)
var discordResponse discordOAuthResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if discordResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"})
}
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope)
return &OAuthToken{
AccessToken: discordResponse.AccessToken,
TokenType: discordResponse.TokenType,
RefreshToken: discordResponse.RefreshToken,
ExpiresIn: discordResponse.ExpiresIn,
Scope: discordResponse.Scope,
IDToken: discordResponse.IDToken,
}, nil
}
func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info")
req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
var discordUser discordUser
err = json.NewDecoder(res.Body).Decode(&discordUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"})
}
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name)
return &OAuthUser{
ProviderUserID: discordUser.UID,
Username: discordUser.ID,
DisplayName: discordUser.Name,
}, nil
}
func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsDiscordIdAlreadyTaken(providerUserID)
}
func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.DiscordId = providerUserID
return user.FillUserByDiscordId()
}
func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.DiscordId = providerUserID
}
func (p *DiscordProvider) GetProviderPrefix() string {
return "discord_"
}

160
oauth/github.go Normal file
View File

@@ -0,0 +1,160 @@
package oauth
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
func init() {
Register("github", &GitHubProvider{})
}
// GitHubProvider implements OAuth for GitHub
type GitHubProvider struct{}
type gitHubOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type gitHubUser struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
func (p *GitHubProvider) GetName() string {
return "GitHub"
}
func (p *GitHubProvider) IsEnabled() bool {
return common.GitHubOAuthEnabled
}
func (p *GitHubProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken: code=%s...", code[:min(len(code), 10)])
values := map[string]string{
"client_id": common.GitHubClientId,
"client_secret": common.GitHubClientSecret,
"code": code,
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken response status: %d", res.StatusCode)
var oAuthResponse gitHubOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if oAuthResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-GitHub] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "GitHub"})
}
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken success: scope=%s", oAuthResponse.Scope)
return &OAuthToken{
AccessToken: oAuthResponse.AccessToken,
TokenType: oAuthResponse.TokenType,
Scope: oAuthResponse.Scope,
}, nil
}
func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo: fetching user info")
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
var githubUser gitHubUser
err = json.NewDecoder(res.Body).Decode(&githubUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if githubUser.Login == "" {
logger.LogError(ctx, "[OAuth-GitHub] GetUserInfo failed: empty login field")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "GitHub"})
}
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo success: login=%s, name=%s, email=%s", githubUser.Login, githubUser.Name, githubUser.Email)
return &OAuthUser{
ProviderUserID: githubUser.Login,
Username: githubUser.Login,
DisplayName: githubUser.Name,
Email: githubUser.Email,
}, nil
}
func (p *GitHubProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsGitHubIdAlreadyTaken(providerUserID)
}
func (p *GitHubProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.GitHubId = providerUserID
return user.FillUserByGitHubId()
}
func (p *GitHubProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.GitHubId = providerUserID
}
func (p *GitHubProvider) GetProviderPrefix() string {
return "github_"
}

195
oauth/linuxdo.go Normal file
View File

@@ -0,0 +1,195 @@
package oauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
func init() {
Register("linuxdo", &LinuxDOProvider{})
}
// LinuxDOProvider implements OAuth for Linux DO
type LinuxDOProvider struct{}
type linuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func (p *LinuxDOProvider) GetName() string {
return "Linux DO"
}
func (p *LinuxDOProvider) IsEnabled() bool {
return common.LinuxDOOAuthEnabled
}
func (p *LinuxDOProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: code=%s...", code[:min(len(code), 10)])
// Get access token using Basic auth
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: token_endpoint=%s, redirect_uri=%s", tokenEndpoint, redirectURI)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken response status: %d", res.StatusCode)
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if tokenRes.AccessToken == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken failed: %s", tokenRes.Message))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Linux DO"}, tokenRes.Message)
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken success")
return &OAuthToken{
AccessToken: tokenRes.AccessToken,
}, nil
}
func (p *LinuxDOProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: user_endpoint=%s", userEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo response status: %d", res.StatusCode)
var linuxdoUser linuxdoUser
if err := json.NewDecoder(res.Body).Decode(&linuxdoUser); err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if linuxdoUser.Id == 0 {
logger.LogError(ctx, "[OAuth-LinuxDO] GetUserInfo failed: invalid user id")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Linux DO"})
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: id=%d, username=%s, name=%s, trust_level=%d, active=%v, silenced=%v",
linuxdoUser.Id, linuxdoUser.Username, linuxdoUser.Name, linuxdoUser.TrustLevel, linuxdoUser.Active, linuxdoUser.Silenced)
// Check trust level
if linuxdoUser.TrustLevel < common.LinuxDOMinimumTrustLevel {
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo: trust level too low (required=%d, current=%d)",
common.LinuxDOMinimumTrustLevel, linuxdoUser.TrustLevel))
return nil, &TrustLevelError{
Required: common.LinuxDOMinimumTrustLevel,
Current: linuxdoUser.TrustLevel,
}
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo success: id=%d, username=%s", linuxdoUser.Id, linuxdoUser.Username)
return &OAuthUser{
ProviderUserID: strconv.Itoa(linuxdoUser.Id),
Username: linuxdoUser.Username,
DisplayName: linuxdoUser.Name,
Extra: map[string]any{
"trust_level": linuxdoUser.TrustLevel,
"active": linuxdoUser.Active,
"silenced": linuxdoUser.Silenced,
},
}, nil
}
func (p *LinuxDOProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsLinuxDOIdAlreadyTaken(providerUserID)
}
func (p *LinuxDOProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.LinuxDOId = providerUserID
return user.FillUserByLinuxDOId()
}
func (p *LinuxDOProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.LinuxDOId = providerUserID
}
func (p *LinuxDOProvider) GetProviderPrefix() string {
return "linuxdo_"
}
// TrustLevelError indicates the user's trust level is too low
type TrustLevelError struct {
Required int
Current int
}
func (e *TrustLevelError) Error() string {
return "trust level too low"
}

177
oauth/oidc.go Normal file
View File

@@ -0,0 +1,177 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func init() {
Register("oidc", &OIDCProvider{})
}
// OIDCProvider implements OAuth for OIDC
type OIDCProvider struct{}
type oidcOAuthResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type oidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func (p *OIDCProvider) GetName() string {
return "OIDC"
}
func (p *OIDCProvider) IsEnabled() bool {
return system_setting.GetOIDCSettings().Enabled
}
func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)])
settings := system_setting.GetOIDCSettings()
redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)
values := url.Values{}
values.Set("client_id", settings.ClientId)
values.Set("client_secret", settings.ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", redirectUri)
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri)
req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode)
var oidcResponse oidcOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if oidcResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"})
}
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope)
return &OAuthToken{
AccessToken: oidcResponse.AccessToken,
TokenType: oidcResponse.TokenType,
RefreshToken: oidcResponse.RefreshToken,
ExpiresIn: oidcResponse.ExpiresIn,
Scope: oidcResponse.Scope,
IDToken: oidcResponse.IDToken,
}, nil
}
func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
settings := system_setting.GetOIDCSettings()
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
var oidcUser oidcUser
err = json.NewDecoder(res.Body).Decode(&oidcUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email))
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"})
}
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email)
return &OAuthUser{
ProviderUserID: oidcUser.OpenID,
Username: oidcUser.PreferredUsername,
DisplayName: oidcUser.Name,
Email: oidcUser.Email,
}, nil
}
func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsOidcIdAlreadyTaken(providerUserID)
}
func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.OidcId = providerUserID
return user.FillUserByOidcId()
}
func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.OidcId = providerUserID
}
func (p *OIDCProvider) GetProviderPrefix() string {
return "oidc_"
}

36
oauth/provider.go Normal file
View File

@@ -0,0 +1,36 @@
package oauth
import (
"context"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
// Provider defines the interface for OAuth providers
type Provider interface {
// GetName returns the display name of the provider (e.g., "GitHub", "Discord")
GetName() string
// IsEnabled returns whether this OAuth provider is enabled
IsEnabled() bool
// ExchangeToken exchanges the authorization code for an access token
// The gin.Context is passed for providers that need request info (e.g., for redirect_uri)
ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error)
// GetUserInfo retrieves user information using the access token
GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error)
// IsUserIDTaken checks if the provider user ID is already associated with an account
IsUserIDTaken(providerUserID string) bool
// FillUserByProviderID fills the user model by provider user ID
FillUserByProviderID(user *model.User, providerUserID string) error
// SetProviderUserID sets the provider user ID on the user model
SetProviderUserID(user *model.User, providerUserID string)
// GetProviderPrefix returns the prefix for auto-generated usernames (e.g., "github_")
GetProviderPrefix() string
}

43
oauth/registry.go Normal file
View File

@@ -0,0 +1,43 @@
package oauth
import (
"sync"
)
var (
providers = make(map[string]Provider)
mu sync.RWMutex
)
// Register registers an OAuth provider with the given name
func Register(name string, provider Provider) {
mu.Lock()
defer mu.Unlock()
providers[name] = provider
}
// GetProvider returns the OAuth provider for the given name
func GetProvider(name string) Provider {
mu.RLock()
defer mu.RUnlock()
return providers[name]
}
// GetAllProviders returns all registered OAuth providers
func GetAllProviders() map[string]Provider {
mu.RLock()
defer mu.RUnlock()
result := make(map[string]Provider, len(providers))
for k, v := range providers {
result[k] = v
}
return result
}
// IsProviderRegistered checks if a provider is registered
func IsProviderRegistered(name string) bool {
mu.RLock()
defer mu.RUnlock()
_, ok := providers[name]
return ok
}

59
oauth/types.go Normal file
View File

@@ -0,0 +1,59 @@
package oauth
// OAuthToken represents the token received from OAuth provider
type OAuthToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Scope string `json:"scope,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
// OAuthUser represents the user info from OAuth provider
type OAuthUser struct {
// ProviderUserID is the unique identifier from the OAuth provider
ProviderUserID string
// Username is the username from the OAuth provider (e.g., GitHub login)
Username string
// DisplayName is the display name from the OAuth provider
DisplayName string
// Email is the email from the OAuth provider
Email string
// Extra contains any additional provider-specific data
Extra map[string]any
}
// OAuthError represents a translatable OAuth error
type OAuthError struct {
// MsgKey is the i18n message key
MsgKey string
// Params contains optional parameters for the message template
Params map[string]any
// RawError is the underlying error for logging purposes
RawError string
}
func (e *OAuthError) Error() string {
if e.RawError != "" {
return e.RawError
}
return e.MsgKey
}
// NewOAuthError creates a new OAuth error with the given message key
func NewOAuthError(msgKey string, params map[string]any) *OAuthError {
return &OAuthError{
MsgKey: msgKey,
Params: params,
}
}
// NewOAuthErrorWithRaw creates a new OAuth error with raw error message for logging
func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) *OAuthError {
return &OAuthError{
MsgKey: msgKey,
Params: params,
RawError: rawError,
}
}