mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-29 23:10:35 +00:00
585 lines
17 KiB
Go
585 lines
17 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/oauth"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// CustomOAuthProviderResponse is the JSON payload returned to API clients for
// a custom OAuth provider. The type has no client-secret field, so the secret
// can never appear in a response built from it.
type CustomOAuthProviderResponse struct {
	// Identity and presentation.
	Id      int    `json:"id"`
	Name    string `json:"name"`
	Slug    string `json:"slug"`
	Icon    string `json:"icon"`
	Enabled bool   `json:"enabled"`

	// OAuth client identifier, endpoints and requested scopes.
	ClientId              string `json:"client_id"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`

	// Field-name settings (presumably the keys looked up in the provider's
	// user-info response; confirm against the oauth package).
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Discovery URL, token-request auth style and access-control settings.
	WellKnown           string `json:"well_known"`
	AuthStyle           int    `json:"auth_style"`
	AccessPolicy        string `json:"access_policy"`
	AccessDeniedMessage string `json:"access_denied_message"`
}
|
|
|
|
// UserOAuthBindingResponse is the JSON payload describing one custom OAuth
// binding of a user: the bound provider's id, name, slug and icon, plus the
// user's identifier at that provider.
type UserOAuthBindingResponse struct {
	// Details of the bound provider.
	ProviderId   int    `json:"provider_id"`
	ProviderName string `json:"provider_name"`
	ProviderSlug string `json:"provider_slug"`
	ProviderIcon string `json:"provider_icon"`

	// Identifier stored on the binding itself.
	ProviderUserId string `json:"provider_user_id"`
}
|
|
|
|
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
|
|
return &CustomOAuthProviderResponse{
|
|
Id: p.Id,
|
|
Name: p.Name,
|
|
Slug: p.Slug,
|
|
Icon: p.Icon,
|
|
Enabled: p.Enabled,
|
|
ClientId: p.ClientId,
|
|
AuthorizationEndpoint: p.AuthorizationEndpoint,
|
|
TokenEndpoint: p.TokenEndpoint,
|
|
UserInfoEndpoint: p.UserInfoEndpoint,
|
|
Scopes: p.Scopes,
|
|
UserIdField: p.UserIdField,
|
|
UsernameField: p.UsernameField,
|
|
DisplayNameField: p.DisplayNameField,
|
|
EmailField: p.EmailField,
|
|
WellKnown: p.WellKnown,
|
|
AuthStyle: p.AuthStyle,
|
|
AccessPolicy: p.AccessPolicy,
|
|
AccessDeniedMessage: p.AccessDeniedMessage,
|
|
}
|
|
}
|
|
|
|
// GetCustomOAuthProviders returns all custom OAuth providers
|
|
func GetCustomOAuthProviders(c *gin.Context) {
|
|
providers, err := model.GetAllCustomOAuthProviders()
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
response := make([]*CustomOAuthProviderResponse, len(providers))
|
|
for i, p := range providers {
|
|
response[i] = toCustomOAuthProviderResponse(p)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": response,
|
|
})
|
|
}
|
|
|
|
// GetCustomOAuthProvider returns a single custom OAuth provider by ID
|
|
func GetCustomOAuthProvider(c *gin.Context) {
|
|
idStr := c.Param("id")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "无效的 ID")
|
|
return
|
|
}
|
|
|
|
provider, err := model.GetCustomOAuthProviderById(id)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": toCustomOAuthProviderResponse(provider),
|
|
})
|
|
}
|
|
|
|
// CreateCustomOAuthProviderRequest is the JSON body accepted when creating a
// custom OAuth provider. Fields tagged binding:"required" must be present for
// the request to bind successfully.
type CreateCustomOAuthProviderRequest struct {
	// Identity and presentation.
	Name    string `json:"name" binding:"required"`
	Slug    string `json:"slug" binding:"required"`
	Icon    string `json:"icon"`
	Enabled bool   `json:"enabled"`

	// OAuth client credentials, endpoints and requested scopes.
	ClientId              string `json:"client_id" binding:"required"`
	ClientSecret          string `json:"client_secret" binding:"required"`
	AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
	TokenEndpoint         string `json:"token_endpoint" binding:"required"`
	UserInfoEndpoint      string `json:"user_info_endpoint" binding:"required"`
	Scopes                string `json:"scopes"`

	// Field-name settings (presumably the keys looked up in the provider's
	// user-info response; confirm against the oauth package).
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Discovery URL, token-request auth style and access-control settings.
	WellKnown           string `json:"well_known"`
	AuthStyle           int    `json:"auth_style"`
	AccessPolicy        string `json:"access_policy"`
	AccessDeniedMessage string `json:"access_denied_message"`
}
|
|
|
|
// FetchCustomOAuthDiscoveryRequest is the JSON body accepted by
// FetchCustomOAuthDiscovery.
type FetchCustomOAuthDiscoveryRequest struct {
	// WellKnownURL is the full discovery-document URL; when non-empty it
	// takes precedence over IssuerURL.
	WellKnownURL string `json:"well_known_url"`

	// IssuerURL is the base from which the discovery URL is derived when
	// WellKnownURL is empty.
	IssuerURL string `json:"issuer_url"`
}
|
|
|
|
// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route)
|
|
func FetchCustomOAuthDiscovery(c *gin.Context) {
|
|
var req FetchCustomOAuthDiscoveryRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
|
return
|
|
}
|
|
|
|
wellKnownURL := strings.TrimSpace(req.WellKnownURL)
|
|
issuerURL := strings.TrimSpace(req.IssuerURL)
|
|
|
|
if wellKnownURL == "" && issuerURL == "" {
|
|
common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL")
|
|
return
|
|
}
|
|
|
|
targetURL := wellKnownURL
|
|
if targetURL == "" {
|
|
targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
|
|
}
|
|
targetURL = strings.TrimSpace(targetURL)
|
|
|
|
parsedURL, err := url.Parse(targetURL)
|
|
if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") {
|
|
common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https")
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second)
|
|
defer cancel()
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error())
|
|
return
|
|
}
|
|
httpReq.Header.Set("Accept", "application/json")
|
|
|
|
client := &http.Client{Timeout: 20 * time.Second}
|
|
resp, err := client.Do(httpReq)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error())
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
|
message := strings.TrimSpace(string(body))
|
|
if message == "" {
|
|
message = resp.Status
|
|
}
|
|
common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message)
|
|
return
|
|
}
|
|
|
|
var discovery map[string]any
|
|
if err = common.DecodeJson(resp.Body, &discovery); err != nil {
|
|
common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error())
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": gin.H{
|
|
"well_known_url": targetURL,
|
|
"discovery": discovery,
|
|
},
|
|
})
|
|
}
|
|
|
|
// CreateCustomOAuthProvider creates a new custom OAuth provider
|
|
func CreateCustomOAuthProvider(c *gin.Context) {
|
|
var req CreateCustomOAuthProviderRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// Check if slug is already taken
|
|
if model.IsSlugTaken(req.Slug, 0) {
|
|
common.ApiErrorMsg(c, "该 Slug 已被使用")
|
|
return
|
|
}
|
|
|
|
// Check if slug conflicts with built-in providers
|
|
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
|
|
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
|
|
return
|
|
}
|
|
|
|
provider := &model.CustomOAuthProvider{
|
|
Name: req.Name,
|
|
Slug: req.Slug,
|
|
Icon: req.Icon,
|
|
Enabled: req.Enabled,
|
|
ClientId: req.ClientId,
|
|
ClientSecret: req.ClientSecret,
|
|
AuthorizationEndpoint: req.AuthorizationEndpoint,
|
|
TokenEndpoint: req.TokenEndpoint,
|
|
UserInfoEndpoint: req.UserInfoEndpoint,
|
|
Scopes: req.Scopes,
|
|
UserIdField: req.UserIdField,
|
|
UsernameField: req.UsernameField,
|
|
DisplayNameField: req.DisplayNameField,
|
|
EmailField: req.EmailField,
|
|
WellKnown: req.WellKnown,
|
|
AuthStyle: req.AuthStyle,
|
|
AccessPolicy: req.AccessPolicy,
|
|
AccessDeniedMessage: req.AccessDeniedMessage,
|
|
}
|
|
|
|
if err := model.CreateCustomOAuthProvider(provider); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
// Register the provider in the OAuth registry
|
|
oauth.RegisterOrUpdateCustomProvider(provider)
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "创建成功",
|
|
"data": toCustomOAuthProviderResponse(provider),
|
|
})
|
|
}
|
|
|
|
// UpdateCustomOAuthProviderRequest is the JSON body accepted when updating a
// custom OAuth provider. Every field is optional: plain string fields keep the
// stored value when empty, pointer fields keep the stored value when nil (so a
// pointer field can be explicitly set to its zero value).
type UpdateCustomOAuthProviderRequest struct {
	// Kept as stored when empty.
	Name string `json:"name"`
	Slug string `json:"slug"`

	// Kept as stored when nil.
	Icon    *string `json:"icon"`
	Enabled *bool   `json:"enabled"`

	// Kept as stored when empty; in particular an empty ClientSecret leaves
	// the existing secret in place.
	ClientId              string `json:"client_id"`
	ClientSecret          string `json:"client_secret"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`
	UserIdField           string `json:"user_id_field"`
	UsernameField         string `json:"username_field"`
	DisplayNameField      string `json:"display_name_field"`
	EmailField            string `json:"email_field"`

	// Kept as stored when nil.
	WellKnown           *string `json:"well_known"`
	AuthStyle           *int    `json:"auth_style"`
	AccessPolicy        *string `json:"access_policy"`
	AccessDeniedMessage *string `json:"access_denied_message"`
}
|
|
|
|
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
|
func UpdateCustomOAuthProvider(c *gin.Context) {
|
|
idStr := c.Param("id")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "无效的 ID")
|
|
return
|
|
}
|
|
|
|
var req UpdateCustomOAuthProviderRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
|
return
|
|
}
|
|
|
|
// Get existing provider
|
|
provider, err := model.GetCustomOAuthProviderById(id)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
|
return
|
|
}
|
|
|
|
oldSlug := provider.Slug
|
|
|
|
// Check if new slug is taken by another provider
|
|
if req.Slug != "" && req.Slug != provider.Slug {
|
|
if model.IsSlugTaken(req.Slug, id) {
|
|
common.ApiErrorMsg(c, "该 Slug 已被使用")
|
|
return
|
|
}
|
|
// Check if slug conflicts with built-in providers
|
|
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
|
|
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
|
|
return
|
|
}
|
|
}
|
|
|
|
// Update fields
|
|
if req.Name != "" {
|
|
provider.Name = req.Name
|
|
}
|
|
if req.Slug != "" {
|
|
provider.Slug = req.Slug
|
|
}
|
|
if req.Icon != nil {
|
|
provider.Icon = *req.Icon
|
|
}
|
|
if req.Enabled != nil {
|
|
provider.Enabled = *req.Enabled
|
|
}
|
|
if req.ClientId != "" {
|
|
provider.ClientId = req.ClientId
|
|
}
|
|
if req.ClientSecret != "" {
|
|
provider.ClientSecret = req.ClientSecret
|
|
}
|
|
if req.AuthorizationEndpoint != "" {
|
|
provider.AuthorizationEndpoint = req.AuthorizationEndpoint
|
|
}
|
|
if req.TokenEndpoint != "" {
|
|
provider.TokenEndpoint = req.TokenEndpoint
|
|
}
|
|
if req.UserInfoEndpoint != "" {
|
|
provider.UserInfoEndpoint = req.UserInfoEndpoint
|
|
}
|
|
if req.Scopes != "" {
|
|
provider.Scopes = req.Scopes
|
|
}
|
|
if req.UserIdField != "" {
|
|
provider.UserIdField = req.UserIdField
|
|
}
|
|
if req.UsernameField != "" {
|
|
provider.UsernameField = req.UsernameField
|
|
}
|
|
if req.DisplayNameField != "" {
|
|
provider.DisplayNameField = req.DisplayNameField
|
|
}
|
|
if req.EmailField != "" {
|
|
provider.EmailField = req.EmailField
|
|
}
|
|
if req.WellKnown != nil {
|
|
provider.WellKnown = *req.WellKnown
|
|
}
|
|
if req.AuthStyle != nil {
|
|
provider.AuthStyle = *req.AuthStyle
|
|
}
|
|
if req.AccessPolicy != nil {
|
|
provider.AccessPolicy = *req.AccessPolicy
|
|
}
|
|
if req.AccessDeniedMessage != nil {
|
|
provider.AccessDeniedMessage = *req.AccessDeniedMessage
|
|
}
|
|
|
|
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
// Update the provider in the OAuth registry
|
|
if oldSlug != provider.Slug {
|
|
oauth.UnregisterCustomProvider(oldSlug)
|
|
}
|
|
oauth.RegisterOrUpdateCustomProvider(provider)
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "更新成功",
|
|
"data": toCustomOAuthProviderResponse(provider),
|
|
})
|
|
}
|
|
|
|
// DeleteCustomOAuthProvider deletes a custom OAuth provider
|
|
func DeleteCustomOAuthProvider(c *gin.Context) {
|
|
idStr := c.Param("id")
|
|
id, err := strconv.Atoi(idStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "无效的 ID")
|
|
return
|
|
}
|
|
|
|
// Get existing provider to get slug
|
|
provider, err := model.GetCustomOAuthProviderById(id)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
|
return
|
|
}
|
|
|
|
// Check if there are any user bindings
|
|
count, err := model.GetBindingCountByProviderId(id)
|
|
if err != nil {
|
|
common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
|
|
common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
|
|
return
|
|
}
|
|
if count > 0 {
|
|
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
|
|
return
|
|
}
|
|
|
|
if err := model.DeleteCustomOAuthProvider(id); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
// Unregister the provider from the OAuth registry
|
|
oauth.UnregisterCustomProvider(provider.Slug)
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "删除成功",
|
|
})
|
|
}
|
|
|
|
func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) {
|
|
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
response := make([]UserOAuthBindingResponse, 0, len(bindings))
|
|
for _, binding := range bindings {
|
|
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
response = append(response, UserOAuthBindingResponse{
|
|
ProviderId: binding.ProviderId,
|
|
ProviderName: provider.Name,
|
|
ProviderSlug: provider.Slug,
|
|
ProviderIcon: provider.Icon,
|
|
ProviderUserId: binding.ProviderUserId,
|
|
})
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// GetUserOAuthBindings returns all OAuth bindings for the current user
|
|
func GetUserOAuthBindings(c *gin.Context) {
|
|
userId := c.GetInt("id")
|
|
if userId == 0 {
|
|
common.ApiErrorMsg(c, "未登录")
|
|
return
|
|
}
|
|
|
|
response, err := buildUserOAuthBindingsResponse(userId)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": response,
|
|
})
|
|
}
|
|
|
|
func GetUserOAuthBindingsByAdmin(c *gin.Context) {
|
|
userIdStr := c.Param("id")
|
|
userId, err := strconv.Atoi(userIdStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "invalid user id")
|
|
return
|
|
}
|
|
|
|
targetUser, err := model.GetUserById(userId, false)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
myRole := c.GetInt("role")
|
|
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
|
common.ApiErrorMsg(c, "no permission")
|
|
return
|
|
}
|
|
|
|
response, err := buildUserOAuthBindingsResponse(userId)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": response,
|
|
})
|
|
}
|
|
|
|
// UnbindCustomOAuth unbinds a custom OAuth provider from the current user
|
|
func UnbindCustomOAuth(c *gin.Context) {
|
|
userId := c.GetInt("id")
|
|
if userId == 0 {
|
|
common.ApiErrorMsg(c, "未登录")
|
|
return
|
|
}
|
|
|
|
providerIdStr := c.Param("provider_id")
|
|
providerId, err := strconv.Atoi(providerIdStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "无效的提供商 ID")
|
|
return
|
|
}
|
|
|
|
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "解绑成功",
|
|
})
|
|
}
|
|
|
|
func UnbindCustomOAuthByAdmin(c *gin.Context) {
|
|
userIdStr := c.Param("id")
|
|
userId, err := strconv.Atoi(userIdStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "invalid user id")
|
|
return
|
|
}
|
|
|
|
targetUser, err := model.GetUserById(userId, false)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
myRole := c.GetInt("role")
|
|
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
|
common.ApiErrorMsg(c, "no permission")
|
|
return
|
|
}
|
|
|
|
providerIdStr := c.Param("provider_id")
|
|
providerId, err := strconv.Atoi(providerIdStr)
|
|
if err != nil {
|
|
common.ApiErrorMsg(c, "invalid provider id")
|
|
return
|
|
}
|
|
|
|
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "success",
|
|
})
|
|
}
|