Files
new-api/oauth/generic.go
CaIon af54ea85d2 feat(oauth): implement custom OAuth provider management #1106
- Add support for custom OAuth providers, including creation, retrieval, updating, and deletion.
- Introduce new model and controller for managing custom OAuth providers.
- Enhance existing OAuth logic to accommodate custom providers.
- Update API routes for custom OAuth provider management.
- Include i18n support for custom OAuth-related messages.
2026-02-05 21:18:43 +08:00

269 lines
9.0 KiB
Go

package oauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// Client-authentication styles for the token request, compared against
// CustomOAuthProvider.AuthStyle. The values presumably persist with the
// provider configuration, so do not renumber them.
const (
	// AuthStyleAutoDetect leaves the choice to the implementation. The server
	// is not probed: ExchangeToken currently treats it as AuthStyleInParams.
	AuthStyleAutoDetect = 0
	// AuthStyleInParams sends client_id and client_secret as form parameters
	// in the POST body.
	AuthStyleInParams = 1
	// AuthStyleInHeader sends the client credentials as an HTTP Basic
	// Authorization header.
	AuthStyleInHeader = 2
)
// GenericOAuthProvider is an OAuth provider whose endpoints, client
// credentials and profile-field mappings are all read from a
// model.CustomOAuthProvider record rather than hard-coded.
type GenericOAuthProvider struct {
	config *model.CustomOAuthProvider // held by pointer; never copied
}
// NewGenericOAuthProvider wraps config in a GenericOAuthProvider. The pointer
// is retained as-is, so later changes to *config are seen by the provider.
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
	provider := new(GenericOAuthProvider)
	provider.config = config
	return provider
}
// GetName returns the configured provider name (config.Name).
func (p *GenericOAuthProvider) GetName() string {
	cfg := p.config
	return cfg.Name
}
// IsEnabled reports the provider config's Enabled flag.
func (p *GenericOAuthProvider) IsEnabled() bool {
	enabled := p.config.Enabled
	return enabled
}
// GetConfig returns the underlying provider configuration. The returned
// pointer is shared with the provider, not a copy.
func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
	cfg := p.config
	return cfg
}
// ExchangeToken trades an authorization code for tokens at
// p.config.TokenEndpoint (authorization_code grant, RFC 6749 §4.1.3).
//
// redirect_uri is rebuilt as "<ServerAddress>/oauth/<slug>" and must match
// the value used in the authorization request. Client credentials are sent
// as form parameters or as an HTTP Basic header according to
// p.config.AuthStyle. AuthStyleAutoDetect is treated as AuthStyleInParams.
//
// The response may be JSON or application/x-www-form-urlencoded. An
// OAuthError is returned when the request cannot be sent, when the server
// reports an OAuth "error", when the status is not 2xx, when the body cannot
// be parsed, or when no access_token is present. The gin context c is unused.
func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
	// Token responses are small documents; cap the read so a misbehaving
	// endpoint cannot exhaust memory.
	const maxTokenResponseBytes = 1 << 20

	if code == "" {
		return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
	}
	logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])

	redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
	values := url.Values{}
	values.Set("grant_type", "authorization_code")
	values.Set("code", code)
	values.Set("redirect_uri", redirectUri)

	authStyle := p.config.AuthStyle
	if authStyle == AuthStyleAutoDetect {
		// No server probing is done; most servers accept credentials in the body.
		authStyle = AuthStyleInParams
	}
	if authStyle == AuthStyleInParams {
		values.Set("client_id", p.config.ClientId)
		values.Set("client_secret", p.config.ClientSecret)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	if authStyle == AuthStyleInHeader {
		credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
		req.Header.Set("Authorization", "Basic "+credentials)
	}
	logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
		p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)

	client := http.Client{Timeout: 20 * time.Second}
	res, err := client.Do(req)
	if err != nil {
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
		return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
	}
	defer res.Body.Close()
	logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)

	body, err := io.ReadAll(io.LimitReader(res.Body, maxTokenResponseBytes))
	if err != nil {
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
		return nil, err
	}
	// A successful body carries live credentials (access, refresh and ID
	// tokens), so the body text is only logged on failure paths.
	logBodyOnFailure := func() {
		logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, string(body[:min(len(body), 500)]))
	}

	tr, parseErr := parseGenericTokenResponse(body)
	providerParams := map[string]any{"Provider": p.config.Name}

	// Check the OAuth error first: RFC 6749 §5.2 error responses usually come
	// with HTTP 400 and are the most informative thing to report.
	if tr.Error != "" {
		logBodyOnFailure()
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
			p.config.Slug, tr.Error, tr.ErrorDesc))
		detail := tr.ErrorDesc
		if detail == "" {
			detail = tr.Error
		}
		return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, providerParams, detail)
	}
	if res.StatusCode < 200 || res.StatusCode > 299 {
		logBodyOnFailure()
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: status=%d", p.config.Slug, res.StatusCode))
		return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, providerParams, res.Status)
	}
	if parseErr != nil {
		logBodyOnFailure()
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, parseErr.Error()))
		return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, providerParams, parseErr.Error())
	}
	if tr.AccessToken == "" {
		logBodyOnFailure()
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
		return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, providerParams)
	}

	logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tr.Scope)
	return &OAuthToken{
		AccessToken:  tr.AccessToken,
		TokenType:    tr.TokenType,
		RefreshToken: tr.RefreshToken,
		ExpiresIn:    int(tr.ExpiresIn),
		Scope:        tr.Scope,
		IDToken:      tr.IDToken,
	}, nil
}

// genericTokenResponse holds the token-endpoint fields ExchangeToken reads,
// covering both the success (RFC 6749 §5.1) and error (§5.2) shapes.
type genericTokenResponse struct {
	AccessToken  string           `json:"access_token"`
	TokenType    string           `json:"token_type"`
	RefreshToken string           `json:"refresh_token"`
	ExpiresIn    genericExpiresIn `json:"expires_in"`
	Scope        string           `json:"scope"`
	IDToken      string           `json:"id_token"`
	Error        string           `json:"error"`
	ErrorDesc    string           `json:"error_description"`
}

// genericExpiresIn decodes expires_in. RFC 6749 §5.1 specifies a JSON number,
// but some providers send a quoted string such as "3600", so both forms are
// accepted. Because the field is optional, a null or unparseable value
// decodes as 0 (lifetime unknown) instead of failing the whole response.
type genericExpiresIn int

// UnmarshalJSON implements json.Unmarshaler for genericExpiresIn.
func (e *genericExpiresIn) UnmarshalJSON(data []byte) error {
	*e = genericExpiresIn(parseGenericExpiresIn(strings.Trim(string(data), `"`)))
	return nil
}

// parseGenericExpiresIn converts an expires_in value to seconds. It returns 0
// for empty, negative, non-numeric or out-of-range (>= 2^31) input.
func parseGenericExpiresIn(s string) int {
	num := json.Number(strings.TrimSpace(s))
	if n, err := num.Int64(); err == nil {
		if n >= 0 && n < 1<<31 {
			return int(n)
		}
		return 0
	}
	// Accept fractional forms such as "3600.0". NaN fails both comparisons.
	if f, err := num.Float64(); err == nil && f >= 0 && f < 1<<31 {
		return int(f)
	}
	return 0
}

// parseGenericTokenResponse decodes a token-endpoint body. A body that is
// valid JSON is decoded as JSON. Anything else is treated as
// application/x-www-form-urlencoded, which some servers (e.g. GitHub) return.
func parseGenericTokenResponse(body []byte) (genericTokenResponse, error) {
	if json.Valid(body) {
		var tr genericTokenResponse
		err := json.Unmarshal(body, &tr)
		return tr, err
	}
	form, err := url.ParseQuery(string(body))
	if err != nil {
		return genericTokenResponse{}, err
	}
	return genericTokenResponse{
		AccessToken:  form.Get("access_token"),
		TokenType:    form.Get("token_type"),
		RefreshToken: form.Get("refresh_token"),
		ExpiresIn:    genericExpiresIn(parseGenericExpiresIn(form.Get("expires_in"))),
		Scope:        form.Get("scope"),
		IDToken:      form.Get("id_token"),
		Error:        form.Get("error"),
		ErrorDesc:    form.Get("error_description"),
	}, nil
}
// GetUserInfo fetches the signed-in user's profile from
// p.config.UserInfoEndpoint using token and maps it to an OAuthUser.
//
// The request carries "Authorization: <token_type> <access_token>". An empty
// token type, or "bearer" in any letter case, is sent as "Bearer". Profile
// fields are extracted from the body with the gjson paths configured in
// UserIdField, UsernameField, DisplayNameField and EmailField; only the user
// ID is required. A transport failure, a non-200 status or a missing, empty
// or null user ID yields an OAuthError.
func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
	// Profile documents are small; cap the read to bound memory use.
	const maxUserInfoBytes = 1 << 20

	logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
	req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
	if err != nil {
		return nil, err
	}

	// Servers often return token_type "bearer" in lower case. Send the
	// RFC 6750 spelling, because some resource servers match the scheme
	// case-sensitively.
	tokenType := token.TokenType
	if tokenType == "" || strings.EqualFold(tokenType, "bearer") {
		tokenType = "Bearer"
	}
	req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
	req.Header.Set("Accept", "application/json")

	client := http.Client{Timeout: 20 * time.Second}
	res, err := client.Do(req)
	if err != nil {
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
		return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
	}
	defer res.Body.Close()
	logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
	if res.StatusCode != http.StatusOK {
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
		return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
	}

	body, err := io.ReadAll(io.LimitReader(res.Body, maxUserInfoBytes))
	if err != nil {
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
		return nil, err
	}
	bodyStr := string(body)
	logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])

	// gjson's Result.String already renders JSON numbers as text, so numeric
	// IDs need no special case. There is deliberately no fallback to the raw
	// JSON text: for an ID of null that would produce the literal "null" and
	// bind every such user to the same account.
	userId := gjson.Get(bodyStr, p.config.UserIdField).String()
	if userId == "" {
		logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
		return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
	}
	username := gjson.Get(bodyStr, p.config.UsernameField).String()
	displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
	email := gjson.Get(bodyStr, p.config.EmailField).String()

	logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
		p.config.Slug, userId, username, displayName, email)
	return &OAuthUser{
		ProviderUserID: userId,
		Username:       username,
		DisplayName:    displayName,
		Email:          email,
	}, nil
}
// IsUserIDTaken reports, via model.IsProviderUserIdTaken, whether
// providerUserID is already recorded for this provider (keyed by config.Id).
// Presumably this means another local account already holds the binding;
// verify against the model implementation.
func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
	providerID := p.config.Id
	return model.IsProviderUserIdTaken(providerID, providerUserID)
}
// FillUserByProviderID looks up the local user bound to providerUserID for
// this provider (config.Id). On success it overwrites *user with that record.
// On failure it returns the lookup error and leaves *user untouched.
func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
	found, lookupErr := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
	if lookupErr == nil {
		*user = *found
	}
	return lookupErr
}
// SetProviderUserID is a deliberate no-op for generic providers: it does not
// modify user. NOTE(review): confirm that every caller relying on this method
// for a generic provider persists the binding itself.
func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
	// Generic providers keep the binding in the user_oauth_bindings table
	// rather than on the User record. That write is handled separately by the
	// OAuth controller.
}
// GetProviderPrefix returns the config slug followed by an underscore,
// e.g. "myidp_" for slug "myidp".
func (p *GenericOAuthProvider) GetProviderPrefix() string {
	const separator = "_"
	slug := p.config.Slug
	return slug + separator
}
// GetProviderId returns config.Id. This is the key that IsUserIDTaken and
// FillUserByProviderID pass to the model layer for this provider's bindings.
func (p *GenericOAuthProvider) GetProviderId() int {
	id := p.config.Id
	return id
}
// IsGenericProvider always reports true for GenericOAuthProvider.
// Presumably callers use it to route account binding through the
// provider-ID based helpers (GetProviderId) rather than SetProviderUserID,
// which is a no-op here. Verify against callers.
func (p *GenericOAuthProvider) IsGenericProvider() bool {
	return true
}