mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 08:36:22 +00:00
- Add support for custom OAuth providers, including creation, retrieval, updating, and deletion.
- Introduce a new model and controller for managing custom OAuth providers.
- Enhance the existing OAuth logic to accommodate custom providers.
- Update API routes for custom OAuth provider management.
- Include i18n support for custom OAuth-related messages.
269 lines
9.0 KiB
Go
269 lines
9.0 KiB
Go
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/i18n"
|
|
"github.com/QuantumNous/new-api/logger"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/setting/system_setting"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// Client-credential transport styles for the token request. ExchangeToken
// compares CustomOAuthProvider.AuthStyle against these values.
const (
	// AuthStyleAutoDetect performs no real detection yet: ExchangeToken
	// treats it the same as AuthStyleInParams.
	AuthStyleAutoDetect = iota
	// AuthStyleInParams sends client_id and client_secret as POST form parameters.
	AuthStyleInParams
	// AuthStyleInHeader sends the client credentials as an HTTP Basic Authorization header.
	AuthStyleInHeader
)
|
|
|
|
// GenericOAuthProvider implements OAuth for custom/generic OAuth providers.
// Instead of hard-coding endpoints, it reads everything from a
// model.CustomOAuthProvider record: token and user-info endpoints, client
// credentials, auth style, and the gjson paths used to pick user fields.
type GenericOAuthProvider struct {
	// config is the provider definition. It is held by pointer, as stored by
	// NewGenericOAuthProvider, and is never copied.
	config *model.CustomOAuthProvider
}
|
|
|
|
// NewGenericOAuthProvider creates a new generic OAuth provider from config
|
|
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
|
|
return &GenericOAuthProvider{config: config}
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetName() string {
|
|
return p.config.Name
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) IsEnabled() bool {
|
|
return p.config.Enabled
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
|
|
return p.config
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
|
if code == "" {
|
|
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
|
|
|
|
redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
|
|
values := url.Values{}
|
|
values.Set("grant_type", "authorization_code")
|
|
values.Set("code", code)
|
|
values.Set("redirect_uri", redirectUri)
|
|
|
|
// Determine auth style
|
|
authStyle := p.config.AuthStyle
|
|
if authStyle == AuthStyleAutoDetect {
|
|
// Default to params style for most OAuth servers
|
|
authStyle = AuthStyleInParams
|
|
}
|
|
|
|
var req *http.Request
|
|
var err error
|
|
|
|
if authStyle == AuthStyleInParams {
|
|
values.Set("client_id", p.config.ClientId)
|
|
values.Set("client_secret", p.config.ClientSecret)
|
|
}
|
|
|
|
req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
if authStyle == AuthStyleInHeader {
|
|
// Basic Auth
|
|
credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
|
|
req.Header.Set("Authorization", "Basic "+credentials)
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
|
|
p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
|
|
|
|
client := http.Client{
|
|
Timeout: 20 * time.Second,
|
|
}
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
|
|
return nil, err
|
|
}
|
|
|
|
bodyStr := string(body)
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
|
|
|
|
// Try to parse as JSON first
|
|
var tokenResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
TokenType string `json:"token_type"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
ErrorDesc string `json:"error_description"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &tokenResponse); err != nil {
|
|
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
|
|
parsedValues, parseErr := url.ParseQuery(bodyStr)
|
|
if parseErr != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
|
|
return nil, err
|
|
}
|
|
tokenResponse.AccessToken = parsedValues.Get("access_token")
|
|
tokenResponse.TokenType = parsedValues.Get("token_type")
|
|
tokenResponse.Scope = parsedValues.Get("scope")
|
|
}
|
|
|
|
if tokenResponse.Error != "" {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
|
|
p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
|
|
}
|
|
|
|
if tokenResponse.AccessToken == "" {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
|
|
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
|
|
|
|
return &OAuthToken{
|
|
AccessToken: tokenResponse.AccessToken,
|
|
TokenType: tokenResponse.TokenType,
|
|
RefreshToken: tokenResponse.RefreshToken,
|
|
ExpiresIn: tokenResponse.ExpiresIn,
|
|
Scope: tokenResponse.Scope,
|
|
IDToken: tokenResponse.IDToken,
|
|
}, nil
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set authorization header
|
|
tokenType := token.TokenType
|
|
if tokenType == "" {
|
|
tokenType = "Bearer"
|
|
}
|
|
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
client := http.Client{
|
|
Timeout: 20 * time.Second,
|
|
}
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
|
|
|
|
if res.StatusCode != http.StatusOK {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
|
|
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
|
}
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
|
|
return nil, err
|
|
}
|
|
|
|
bodyStr := string(body)
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
|
|
|
|
// Extract fields using gjson (supports JSONPath-like syntax)
|
|
userId := gjson.Get(bodyStr, p.config.UserIdField).String()
|
|
username := gjson.Get(bodyStr, p.config.UsernameField).String()
|
|
displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
|
|
email := gjson.Get(bodyStr, p.config.EmailField).String()
|
|
|
|
// If user ID field returns a number, convert it
|
|
if userId == "" {
|
|
// Try to get as number
|
|
userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
|
|
if userIdNum.Exists() {
|
|
userId = userIdNum.Raw
|
|
// Remove quotes if present
|
|
userId = strings.Trim(userId, "\"")
|
|
}
|
|
}
|
|
|
|
if userId == "" {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
|
|
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
|
|
p.config.Slug, userId, username, displayName, email)
|
|
|
|
return &OAuthUser{
|
|
ProviderUserID: userId,
|
|
Username: username,
|
|
DisplayName: displayName,
|
|
Email: email,
|
|
}, nil
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
|
|
return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
|
foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*user = *foundUser
|
|
return nil
|
|
}
|
|
|
|
// SetProviderUserID is intentionally a no-op for generic providers: it does
// not modify user. It exists so the type offers the same method set as other
// providers (presumably an interface requirement — confirm against the
// provider interface).
func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
	// For generic providers, we store the binding in user_oauth_bindings table
	// This is handled separately in the OAuth controller
}
|
|
|
|
func (p *GenericOAuthProvider) GetProviderPrefix() string {
|
|
return p.config.Slug + "_"
|
|
}
|
|
|
|
// GetProviderId returns the provider ID for binding purposes
|
|
func (p *GenericOAuthProvider) GetProviderId() int {
|
|
return p.config.Id
|
|
}
|
|
|
|
// IsGenericProvider returns true for generic providers
|
|
func (p *GenericOAuthProvider) IsGenericProvider() bool {
|
|
return true
|
|
}
|