diff --git a/controller/custom_oauth.go b/controller/custom_oauth.go index e2245f880..3197a9163 100644 --- a/controller/custom_oauth.go +++ b/controller/custom_oauth.go @@ -1,8 +1,13 @@ package controller import ( + "context" + "io" "net/http" + "net/url" "strconv" + "strings" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -16,6 +21,7 @@ type CustomOAuthProviderResponse struct { Id int `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + Icon string `json:"icon"` Enabled bool `json:"enabled"` ClientId string `json:"client_id"` AuthorizationEndpoint string `json:"authorization_endpoint"` @@ -28,6 +34,8 @@ type CustomOAuthProviderResponse struct { EmailField string `json:"email_field"` WellKnown string `json:"well_known"` AuthStyle int `json:"auth_style"` + AccessPolicy string `json:"access_policy"` + AccessDeniedMessage string `json:"access_denied_message"` } func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse { @@ -35,6 +43,7 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro Id: p.Id, Name: p.Name, Slug: p.Slug, + Icon: p.Icon, Enabled: p.Enabled, ClientId: p.ClientId, AuthorizationEndpoint: p.AuthorizationEndpoint, @@ -47,6 +56,8 @@ func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthPro EmailField: p.EmailField, WellKnown: p.WellKnown, AuthStyle: p.AuthStyle, + AccessPolicy: p.AccessPolicy, + AccessDeniedMessage: p.AccessDeniedMessage, } } @@ -96,6 +107,7 @@ func GetCustomOAuthProvider(c *gin.Context) { type CreateCustomOAuthProviderRequest struct { Name string `json:"name" binding:"required"` Slug string `json:"slug" binding:"required"` + Icon string `json:"icon"` Enabled bool `json:"enabled"` ClientId string `json:"client_id" binding:"required"` ClientSecret string `json:"client_secret" binding:"required"` @@ -109,6 +121,85 @@ type CreateCustomOAuthProviderRequest struct { EmailField string `json:"email_field"` WellKnown string `json:"well_known"` AuthStyle int `json:"auth_style"` + AccessPolicy string `json:"access_policy"` + AccessDeniedMessage string `json:"access_denied_message"` +} + +type FetchCustomOAuthDiscoveryRequest struct { + WellKnownURL string `json:"well_known_url"` + IssuerURL string `json:"issuer_url"` +} + +// FetchCustomOAuthDiscovery fetches OIDC discovery document via backend (root-only route) +func FetchCustomOAuthDiscovery(c *gin.Context) { + var req FetchCustomOAuthDiscoveryRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) + return + } + + wellKnownURL := strings.TrimSpace(req.WellKnownURL) + issuerURL := strings.TrimSpace(req.IssuerURL) + + if wellKnownURL == "" && issuerURL == "" { + common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL") + return + } + + targetURL := wellKnownURL + if targetURL == "" { + targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" + } + targetURL = strings.TrimSpace(targetURL) + + parsedURL, err := url.Parse(targetURL) + if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { + common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https") + return + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error()) + return + } + httpReq.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 20 * time.Second} + resp, err := client.Do(httpReq) + if err != nil { + common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error()) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + message := strings.TrimSpace(string(body)) + if message == "" { + message = resp.Status + } + common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message) + return + } + + var discovery map[string]any + if err = common.DecodeJson(resp.Body, &discovery); err != nil { + common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "well_known_url": targetURL, + "discovery": discovery, + }, + }) } // CreateCustomOAuthProvider creates a new custom OAuth provider @@ -134,6 +225,7 @@ func CreateCustomOAuthProvider(c *gin.Context) { provider := &model.CustomOAuthProvider{ Name: req.Name, Slug: req.Slug, + Icon: req.Icon, Enabled: req.Enabled, ClientId: req.ClientId, ClientSecret: req.ClientSecret, @@ -147,6 +239,8 @@ func CreateCustomOAuthProvider(c *gin.Context) { EmailField: req.EmailField, WellKnown: req.WellKnown, AuthStyle: req.AuthStyle, + AccessPolicy: req.AccessPolicy, + AccessDeniedMessage: req.AccessDeniedMessage, } if err := model.CreateCustomOAuthProvider(provider); err != nil { @@ -168,9 +262,10 @@ func CreateCustomOAuthProvider(c *gin.Context) { type UpdateCustomOAuthProviderRequest struct { Name string `json:"name"` Slug string `json:"slug"` - Enabled *bool `json:"enabled"` // Optional: if nil, keep existing + Icon *string `json:"icon"` // Optional: if nil, keep existing + Enabled *bool `json:"enabled"` // Optional: if nil, keep existing ClientId string `json:"client_id"` - ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing + ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing AuthorizationEndpoint string `json:"authorization_endpoint"` TokenEndpoint string `json:"token_endpoint"` UserInfoEndpoint string `json:"user_info_endpoint"` @@ -181,6 +276,8 @@ type UpdateCustomOAuthProviderRequest struct { EmailField string `json:"email_field"` WellKnown *string `json:"well_known"` // Optional: if nil, keep existing AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing + AccessPolicy *string `json:"access_policy"` // Optional: if nil, keep existing + AccessDeniedMessage *string `json:"access_denied_message"` // Optional: if nil, keep existing } // UpdateCustomOAuthProvider updates an existing custom OAuth provider @@ -227,6 +324,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.Slug != "" { provider.Slug = req.Slug } + if req.Icon != nil { + provider.Icon = *req.Icon + } if req.Enabled != nil { provider.Enabled = *req.Enabled } @@ -266,6 +366,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) { if req.AuthStyle != nil { provider.AuthStyle = *req.AuthStyle } + if req.AccessPolicy != nil { + provider.AccessPolicy = *req.AccessPolicy + } + if req.AccessDeniedMessage != nil { + provider.AccessDeniedMessage = *req.AccessDeniedMessage + } if err := model.UpdateCustomOAuthProvider(provider); err != nil { common.ApiError(c, err) @@ -346,6 +452,7 @@ func GetUserOAuthBindings(c *gin.Context) { ProviderId int `json:"provider_id"` ProviderName string `json:"provider_name"` ProviderSlug string `json:"provider_slug"` + ProviderIcon string `json:"provider_icon"` ProviderUserId string `json:"provider_user_id"` } @@ -359,6 +466,7 @@ func GetUserOAuthBindings(c *gin.Context) { ProviderId: binding.ProviderId, ProviderName: provider.Name, ProviderSlug: provider.Slug, + ProviderIcon: provider.Icon, ProviderUserId: binding.ProviderUserId, }) } diff --git a/controller/misc.go b/controller/misc.go index a16e2d554..b24a74adf 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -134,8 +134,10 @@ func GetStatus(c *gin.Context) { customProviders := oauth.GetEnabledCustomProviders() if len(customProviders) > 0 { type CustomOAuthInfo struct { + Id int `json:"id"` Name string `json:"name"` Slug string `json:"slug"` + Icon string `json:"icon"` ClientId string `json:"client_id"` AuthorizationEndpoint string `json:"authorization_endpoint"` Scopes string `json:"scopes"` @@ -144,8 +146,10 @@ func GetStatus(c *gin.Context) { for _, p := range customProviders { config := p.GetConfig() providersInfo = append(providersInfo, CustomOAuthInfo{ + Id: config.Id, Name: config.Name, Slug: config.Slug, + Icon: config.Icon, ClientId: config.ClientId, AuthorizationEndpoint: config.AuthorizationEndpoint, Scopes: config.Scopes, diff --git a/controller/oauth.go b/controller/oauth.go index 65e18f9da..75ab29898 100644 --- a/controller/oauth.go +++ b/controller/oauth.go @@ -295,12 +295,12 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o // Set the provider user ID on the user model and update provider.SetProviderUserID(user, oauthUser.ProviderUserID) if err := tx.Model(user).Updates(map[string]interface{}{ - "github_id": user.GitHubId, - "discord_id": user.DiscordId, - "oidc_id": user.OidcId, - "linux_do_id": user.LinuxDOId, - "wechat_id": user.WeChatId, - "telegram_id": user.TelegramId, + "github_id": user.GitHubId, + "discord_id": user.DiscordId, + "oidc_id": user.OidcId, + "linux_do_id": user.LinuxDOId, + "wechat_id": user.WeChatId, + "telegram_id": user.TelegramId, }).Error; err != nil { return err } @@ -340,6 +340,8 @@ func handleOAuthError(c *gin.Context, err error) { } else { common.ApiErrorI18n(c, e.MsgKey) } + case *oauth.AccessDeniedError: + common.ApiErrorMsg(c, e.Message) case *oauth.TrustLevelError: common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow) default: diff --git a/model/custom_oauth_provider.go b/model/custom_oauth_provider.go index 43c69833a..12b4d1111 100644 --- a/model/custom_oauth_provider.go +++ b/model/custom_oauth_provider.go @@ -2,32 +2,65 @@ package model import ( "errors" + "fmt" "strings" "time" + + "github.com/QuantumNous/new-api/common" ) +type accessPolicyPayload struct { + Logic string `json:"logic"` + Conditions []accessConditionItem `json:"conditions"` + Groups []accessPolicyPayload `json:"groups"` +} + +type accessConditionItem struct { + Field string `json:"field"` + Op string `json:"op"` + Value any `json:"value"` +} + +var supportedAccessPolicyOps = map[string]struct{}{ + "eq": {}, + "ne": {}, + "gt": {}, + "gte": {}, + "lt": {}, + "lte": {}, + "in": {}, + "not_in": {}, + "contains": {}, + "not_contains": {}, + "exists": {}, + "not_exists": {}, +} + // CustomOAuthProvider stores configuration for custom OAuth providers type CustomOAuthProvider struct { - Id int `json:"id" gorm:"primaryKey"` - Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" - Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" - Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled - ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID - ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) - AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL - TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL - UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL - Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes + Id int `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise" + Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise" + Icon string `json:"icon" gorm:"type:varchar(128);default:''"` // Icon name from @lobehub/icons + Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled + ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID + ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend) + AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL + TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL + UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL + Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes // Field mapping configuration (supports JSONPath via gjson) - UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" - UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path - DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path - EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path + UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id" + UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path + DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path + EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path // Advanced options - WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) - AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional) + AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth) + AccessPolicy string `json:"access_policy" gorm:"type:text"` // JSON policy for access control based on user info + AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"` // Custom error message template when access is denied CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -158,6 +191,57 @@ func validateCustomOAuthProvider(provider *CustomOAuthProvider) error { if provider.Scopes == "" { provider.Scopes = "openid profile email" } + if strings.TrimSpace(provider.AccessPolicy) != "" { + var policy accessPolicyPayload + if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil { + return errors.New("access_policy must be valid JSON") + } + if err := validateAccessPolicyPayload(&policy); err != nil { + return fmt.Errorf("access_policy is invalid: %w", err) + } + } + + return nil +} + +func validateAccessPolicyPayload(policy *accessPolicyPayload) error { + if policy == nil { + return errors.New("policy is nil") + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + if logic != "and" && logic != "or" { + return fmt.Errorf("unsupported logic: %s", logic) + } + + if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { + return errors.New("policy requires at least one condition or group") + } + + for index, condition := range policy.Conditions { + field := strings.TrimSpace(condition.Field) + if field == "" { + return fmt.Errorf("condition[%d].field is required", index) + } + op := strings.ToLower(strings.TrimSpace(condition.Op)) + if _, ok := supportedAccessPolicyOps[op]; !ok { + return fmt.Errorf("condition[%d].op is unsupported: %s", index, op) + } + if op == "in" || op == "not_in" { + if _, ok := condition.Value.([]any); !ok { + return fmt.Errorf("condition[%d].value must be an array for op %s", index, op) + } + } + } + + for index := range policy.Groups { + if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil { + return fmt.Errorf("group[%d]: %w", index, err) + } + } return nil } diff --git a/oauth/generic.go b/oauth/generic.go index c7aa87931..bc18054d5 100644 --- a/oauth/generic.go +++ b/oauth/generic.go @@ -3,19 +3,24 @@ package oauth import ( "context" "encoding/base64" - "encoding/json" + stdjson "encoding/json" + "errors" "fmt" "io" "net/http" "net/url" + "regexp" + "strconv" "strings" "time" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/i18n" "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/setting/system_setting" "github.com/gin-gonic/gin" + "github.com/samber/lo" "github.com/tidwall/gjson" ) @@ -31,6 +36,40 @@ type GenericOAuthProvider struct { config *model.CustomOAuthProvider } +type accessPolicy struct { + Logic string `json:"logic"` + Conditions []accessCondition `json:"conditions"` + Groups []accessPolicy `json:"groups"` +} + +type accessCondition struct { + Field string `json:"field"` + Op string `json:"op"` + Value any `json:"value"` +} + +type accessPolicyFailure struct { + Field string + Op string + Expected any + Current any +} + +var supportedAccessPolicyOps = []string{ + "eq", + "ne", + "gt", + "gte", + "lt", + "lte", + "in", + "not_in", + "contains", + "not_contains", + "exists", + "not_exists", +} + // NewGenericOAuthProvider creates a new generic OAuth provider from config func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider { return &GenericOAuthProvider{config: config} @@ -125,7 +164,7 @@ func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c ErrorDesc string `json:"error_description"` } - if err := json.Unmarshal(body, &tokenResponse); err != nil { + if err := common.Unmarshal(body, &tokenResponse); err != nil { // Try to parse as URL-encoded (some OAuth servers like GitHub return this format) parsedValues, parseErr := url.ParseQuery(bodyStr) if parseErr != nil { @@ -227,11 +266,30 @@ func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToke logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s", p.config.Slug, userId, username, displayName, email) + policyRaw := strings.TrimSpace(p.config.AccessPolicy) + if policyRaw != "" { + policy, err := parseAccessPolicy(policyRaw) + if err != nil { + logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error())) + return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration") + } + allowed, failure := evaluateAccessPolicy(bodyStr, policy) + if !allowed { + message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure) + logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v", + p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current)) + return nil, &AccessDeniedError{Message: message} + } + } + return &OAuthUser{ ProviderUserID: userId, Username: username, DisplayName: displayName, Email: email, + Extra: map[string]any{ + "provider": p.config.Slug, + }, }, nil } @@ -266,3 +324,345 @@ func (p *GenericOAuthProvider) GetProviderId() int { func (p *GenericOAuthProvider) IsGenericProvider() bool { return true } + +func parseAccessPolicy(raw string) (*accessPolicy, error) { + var policy accessPolicy + if err := common.UnmarshalJsonStr(raw, &policy); err != nil { + return nil, err + } + if err := validateAccessPolicy(&policy); err != nil { + return nil, err + } + return &policy, nil +} + +func validateAccessPolicy(policy *accessPolicy) error { + if policy == nil { + return errors.New("policy is nil") + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + if !lo.Contains([]string{"and", "or"}, logic) { + return fmt.Errorf("unsupported policy logic: %s", logic) + } + policy.Logic = logic + + if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { + return errors.New("policy requires at least one condition or group") + } + + for index := range policy.Conditions { + if err := validateAccessCondition(&policy.Conditions[index], index); err != nil { + return err + } + } + + for index := range policy.Groups { + if err := validateAccessPolicy(&policy.Groups[index]); err != nil { + return fmt.Errorf("invalid policy group[%d]: %w", index, err) + } + } + + return nil +} + +func validateAccessCondition(condition *accessCondition, index int) error { + if condition == nil { + return fmt.Errorf("condition[%d] is nil", index) + } + + condition.Field = strings.TrimSpace(condition.Field) + if condition.Field == "" { + return fmt.Errorf("condition[%d].field is required", index) + } + + condition.Op = normalizePolicyOp(condition.Op) + if !lo.Contains(supportedAccessPolicyOps, condition.Op) { + return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op) + } + + if lo.Contains([]string{"in", "not_in"}, condition.Op) { + if _, ok := condition.Value.([]any); !ok { + return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op) + } + } + + return nil +} + +func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) { + if policy == nil { + return true, nil + } + + logic := strings.ToLower(strings.TrimSpace(policy.Logic)) + if logic == "" { + logic = "and" + } + + hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0 + if !hasAny { + return true, nil + } + + if logic == "or" { + var firstFailure *accessPolicyFailure + for _, cond := range policy.Conditions { + ok, failure := evaluateAccessCondition(body, cond) + if ok { + return true, nil + } + if firstFailure == nil { + firstFailure = failure + } + } + for _, group := range policy.Groups { + ok, failure := evaluateAccessPolicy(body, &group) + if ok { + return true, nil + } + if firstFailure == nil { + firstFailure = failure + } + } + return false, firstFailure + } + + for _, cond := range policy.Conditions { + ok, failure := evaluateAccessCondition(body, cond) + if !ok { + return false, failure + } + } + for _, group := range policy.Groups { + ok, failure := evaluateAccessPolicy(body, &group) + if !ok { + return false, failure + } + } + return true, nil +} + +func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) { + path := cond.Field + op := cond.Op + result := gjson.Get(body, path) + current := gjsonResultToValue(result) + failure := &accessPolicyFailure{ + Field: path, + Op: op, + Expected: cond.Value, + Current: current, + } + + switch op { + case "exists": + return result.Exists(), failure + case "not_exists": + return !result.Exists(), failure + case "eq": + return compareAny(current, cond.Value) == 0, failure + case "ne": + return compareAny(current, cond.Value) != 0, failure + case "gt": + return compareAny(current, cond.Value) > 0, failure + case "gte": + return compareAny(current, cond.Value) >= 0, failure + case "lt": + return compareAny(current, cond.Value) < 0, failure + case "lte": + return compareAny(current, cond.Value) <= 0, failure + case "in": + return valueInSlice(current, cond.Value), failure + case "not_in": + return !valueInSlice(current, cond.Value), failure + case "contains": + return containsValue(current, cond.Value), failure + case "not_contains": + return !containsValue(current, cond.Value), failure + default: + return false, failure + } +} + +func normalizePolicyOp(op string) string { + return strings.ToLower(strings.TrimSpace(op)) +} + +func gjsonResultToValue(result gjson.Result) any { + if !result.Exists() { + return nil + } + if result.IsArray() { + arr := result.Array() + values := make([]any, 0, len(arr)) + for _, item := range arr { + values = append(values, gjsonResultToValue(item)) + } + return values + } + switch result.Type { + case gjson.Null: + return nil + case gjson.True: + return true + case gjson.False: + return false + case gjson.Number: + return result.Num + case gjson.String: + return result.String() + case gjson.JSON: + var data any + if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil { + return data + } + return result.Raw + default: + return result.Value() + } +} + +func compareAny(left any, right any) int { + if lf, ok := toFloat(left); ok { + if rf, ok2 := toFloat(right); ok2 { + switch { + case lf < rf: + return -1 + case lf > rf: + return 1 + default: + return 0 + } + } + } + + ls := strings.TrimSpace(fmt.Sprint(left)) + rs := strings.TrimSpace(fmt.Sprint(right)) + switch { + case ls < rs: + return -1 + case ls > rs: + return 1 + default: + return 0 + } +} + +func toFloat(v any) (float64, bool) { + switch value := v.(type) { + case float64: + return value, true + case float32: + return float64(value), true + case int: + return float64(value), true + case int8: + return float64(value), true + case int16: + return float64(value), true + case int32: + return float64(value), true + case int64: + return float64(value), true + case uint: + return float64(value), true + case uint8: + return float64(value), true + case uint16: + return float64(value), true + case uint32: + return float64(value), true + case uint64: + return float64(value), true + case stdjson.Number: + n, err := value.Float64() + if err == nil { + return n, true + } + case string: + n, err := strconv.ParseFloat(strings.TrimSpace(value), 64) + if err == nil { + return n, true + } + } + return 0, false +} + +func valueInSlice(current any, expected any) bool { + list, ok := expected.([]any) + if !ok { + return false + } + return lo.ContainsBy(list, func(item any) bool { + return compareAny(current, item) == 0 + }) +} + +func containsValue(current any, expected any) bool { + switch value := current.(type) { + case string: + target := strings.TrimSpace(fmt.Sprint(expected)) + return strings.Contains(value, target) + case []any: + return lo.ContainsBy(value, func(item any) bool { + return compareAny(item, expected) == 0 + }) + } + return false +} + +func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string { + defaultMessage := "Access denied: your account does not meet this provider's access requirements." + message := strings.TrimSpace(template) + if message == "" { + return defaultMessage + } + + if failure == nil { + failure = &accessPolicyFailure{} + } + + replacements := map[string]string{ + "{{provider}}": providerName, + "{{field}}": failure.Field, + "{{op}}": failure.Op, + "{{required}}": fmt.Sprint(failure.Expected), + "{{current}}": fmt.Sprint(failure.Current), + } + + for key, value := range replacements { + message = strings.ReplaceAll(message, key, value) + } + + currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`) + message = currentPattern.ReplaceAllStringFunc(message, func(token string) string { + match := currentPattern.FindStringSubmatch(token) + if len(match) != 2 { + return "" + } + path := strings.TrimSpace(match[1]) + if path == "" { + return "" + } + return strings.TrimSpace(gjson.Get(body, path).String()) + }) + + requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`) + message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string { + match := requiredPattern.FindStringSubmatch(token) + if len(match) != 2 { + return "" + } + path := strings.TrimSpace(match[1]) + if failure.Field == path { + return fmt.Sprint(failure.Expected) + } + return "" + }) + + return strings.TrimSpace(message) +} diff --git a/oauth/types.go b/oauth/types.go index 1b0e3646a..383e6f351 100644 --- a/oauth/types.go +++ b/oauth/types.go @@ -57,3 +57,12 @@ func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) RawError: rawError, } } + +// AccessDeniedError is a direct user-facing access denial message. +type AccessDeniedError struct { + Message string +} + +func (e *AccessDeniedError) Error() string { + return e.Message +} diff --git a/router/api-router.go b/router/api-router.go index e2ef2f531..d60ba39b2 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -170,10 +170,11 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } - // Custom OAuth provider management (admin only) + // Custom OAuth provider management (root only) customOAuthRoute := apiRouter.Group("/custom-oauth-provider") customOAuthRoute.Use(middleware.RootAuth()) { + customOAuthRoute.POST("/discovery", controller.FetchCustomOAuthDiscovery) customOAuthRoute.GET("/", controller.GetCustomOAuthProviders) customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider) customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider) diff --git a/web/src/components/auth/LoginForm.jsx b/web/src/components/auth/LoginForm.jsx index 636317e44..7e8c0ce01 100644 --- a/web/src/components/auth/LoginForm.jsx +++ b/web/src/components/auth/LoginForm.jsx @@ -29,6 +29,7 @@ import { showSuccess, updateAPI, getSystemName, + getOAuthProviderIcon, setUserData, onGitHubOAuthClicked, onDiscordOAuthClicked, @@ -130,6 +131,17 @@ const LoginForm = () => { return {}; } }, [statusState?.status]); + const hasCustomOAuthProviders = + (status.custom_oauth_providers || []).length > 0; + const hasOAuthLoginOptions = Boolean( + status.github_oauth || + status.discord_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth || + hasCustomOAuthProviders, + ); useEffect(() => { if (status?.turnstile_check) { @@ -598,7 +610,7 @@ const LoginForm = () => { theme='outline' className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors' type='tertiary' - icon={} + icon={getOAuthProviderIcon(provider.icon || '', 20)} onClick={() => handleCustomOAuthClick(provider)} loading={customOAuthLoading[provider.slug]} > @@ -817,12 +829,7 @@ const LoginForm = () => { - {(status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth) && ( + {hasOAuthLoginOptions && ( <> {t('或')} @@ -952,14 +959,7 @@ const LoginForm = () => { />
{showEmailLogin || - !( - status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth - ) + !hasOAuthLoginOptions ? renderEmailLoginForm() : renderOAuthOptions()} {renderWeChatLoginModal()} diff --git a/web/src/components/auth/RegisterForm.jsx b/web/src/components/auth/RegisterForm.jsx index 2edc499b1..0a755b194 100644 --- a/web/src/components/auth/RegisterForm.jsx +++ b/web/src/components/auth/RegisterForm.jsx @@ -27,8 +27,10 @@ import { showSuccess, updateAPI, getSystemName, + getOAuthProviderIcon, setUserData, onDiscordOAuthClicked, + onCustomOAuthClicked, } from '../../helpers'; import Turnstile from 'react-turnstile'; import { @@ -98,6 +100,7 @@ const RegisterForm = () => { const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); + const [customOAuthLoading, setCustomOAuthLoading] = useState({}); const [disableButton, setDisableButton] = useState(false); const [countdown, setCountdown] = useState(30); const [agreedToTerms, setAgreedToTerms] = useState(false); @@ -126,6 +129,17 @@ const RegisterForm = () => { return {}; } }, [statusState?.status]); + const hasCustomOAuthProviders = + (status.custom_oauth_providers || []).length > 0; + const hasOAuthRegisterOptions = Boolean( + status.github_oauth || + status.discord_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth || + hasCustomOAuthProviders, + ); const [showEmailVerification, setShowEmailVerification] = useState(false); @@ -319,6 +333,17 @@ const RegisterForm = () => { } }; + const handleCustomOAuthClick = (provider) => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true })); + try { + onCustomOAuthClicked(provider, { shouldLogout: true }); + } finally { + setTimeout(() => { + setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false })); + }, 3000); + } + }; + const handleEmailRegisterClick = () => { setEmailRegisterLoading(true); setShowEmailRegister(true); @@ -469,6 +494,23 @@ const RegisterForm = () => { )} + {status.custom_oauth_providers && + status.custom_oauth_providers.map((provider) => ( + + ))} + {status.telegram_oauth && (
{
- {(status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth) && ( + {hasOAuthRegisterOptions && ( <> {t('或')} @@ -745,14 +782,7 @@ const RegisterForm = () => { />
{showEmailRegister || - !( - status.github_oauth || - status.discord_oauth || - status.oidc_enabled || - status.wechat_login || - status.linuxdo_oauth || - status.telegram_oauth - ) + !hasOAuthRegisterOptions ? renderEmailRegisterForm() : renderOAuthOptions()} {renderWeChatLoginModal()} diff --git a/web/src/components/settings/CustomOAuthSetting.jsx b/web/src/components/settings/CustomOAuthSetting.jsx index 4b6df4c81..0912160be 100644 --- a/web/src/components/settings/CustomOAuthSetting.jsx +++ b/web/src/components/settings/CustomOAuthSetting.jsx @@ -27,14 +27,20 @@ import { Modal, Banner, Card, + Collapse, + Switch, Table, Tag, Popconfirm, Space, - Select, } from '@douyinfe/semi-ui'; -import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons'; -import { API, showError, showSuccess } from '../../helpers'; +import { + IconPlus, + IconEdit, + IconDelete, + IconRefresh, +} from '@douyinfe/semi-icons'; +import { API, showError, showSuccess, getOAuthProviderIcon } from '../../helpers'; import { useTranslation } from 'react-i18next'; const { Text } = Typography; @@ -120,6 +126,69 @@ const OAUTH_PRESETS = { }, }; +const OAUTH_PRESET_ICONS = { + 'github-enterprise': 'github', + gitlab: 'gitlab', + gitea: 'gitea', + nextcloud: 'nextcloud', + keycloak: 'keycloak', + authentik: 'authentik', + ory: 'openid', +}; + +const getPresetIcon = (preset) => OAUTH_PRESET_ICONS[preset] || ''; + +const PRESET_RESET_VALUES = { + name: '', + slug: '', + icon: '', + authorization_endpoint: '', + token_endpoint: '', + user_info_endpoint: '', + scopes: '', + user_id_field: '', + username_field: '', + display_name_field: '', + email_field: '', + well_known: '', + auth_style: 0, + access_policy: '', + access_denied_message: '', +}; + +const DISCOVERY_FIELD_LABELS = { + authorization_endpoint: 'Authorization Endpoint', + token_endpoint: 'Token Endpoint', + user_info_endpoint: 'User Info Endpoint', + scopes: 'Scopes', + user_id_field: 'User ID Field', + username_field: 'Username Field', + display_name_field: 'Display Name Field', + email_field: 'Email Field', +}; + +const ACCESS_POLICY_TEMPLATES = { + level_active: `{ + "logic": "and", + "conditions": [ + {"field": "trust_level", "op": "gte", "value": 2}, + {"field": "active", "op": "eq", "value": true} + ] +}`, + org_or_role: `{ + "logic": "or", + "conditions": [ + {"field": "org", "op": "eq", "value": "core"}, + {"field": "roles", "op": "contains", "value": "admin"} + ] +}`, +}; + +const ACCESS_DENIED_TEMPLATES = { + level_hint: '需要等级 {{required}},你当前等级 {{current}}(字段:{{field}})', + org_hint: '仅限指定组织或角色访问。组织={{current.org}},角色={{current.roles}}', +}; + const CustomOAuthSetting = ({ serverAddress }) => { const { t } = useTranslation(); const [providers, setProviders] = useState([]); @@ -129,8 +198,47 @@ const CustomOAuthSetting = ({ serverAddress }) => { const [formValues, setFormValues] = useState({}); const [selectedPreset, setSelectedPreset] = useState(''); const [baseUrl, setBaseUrl] = useState(''); + const [discoveryLoading, setDiscoveryLoading] = useState(false); + const [discoveryInfo, setDiscoveryInfo] = useState(null); + const [advancedActiveKeys, setAdvancedActiveKeys] = useState([]); const formApiRef = React.useRef(null); + const mergeFormValues = (newValues) => { + setFormValues((prev) => ({ ...prev, ...newValues })); + if (!formApiRef.current) return; + Object.entries(newValues).forEach(([key, value]) => { + formApiRef.current.setValue(key, value); + }); + }; + + const getLatestFormValues = () => { + const values = formApiRef.current?.getValues?.(); + return values && typeof values === 'object' ? values : formValues; + }; + + const normalizeBaseUrl = (url) => (url || '').trim().replace(/\/+$/, ''); + + const inferBaseUrlFromProvider = (provider) => { + const endpoint = provider?.authorization_endpoint || provider?.token_endpoint; + if (!endpoint) return ''; + try { + const url = new URL(endpoint); + return `${url.protocol}//${url.host}`; + } catch (error) { + return ''; + } + }; + + const resetDiscoveryState = () => { + setDiscoveryInfo(null); + }; + + const closeModal = () => { + setModalVisible(false); + resetDiscoveryState(); + setAdvancedActiveKeys([]); + }; + const fetchProviders = async () => { setLoading(true); try { @@ -154,23 +262,30 @@ const CustomOAuthSetting = ({ serverAddress }) => { setEditingProvider(null); setFormValues({ enabled: false, + icon: '', scopes: 'openid profile email', user_id_field: 'sub', username_field: 'preferred_username', display_name_field: 'name', email_field: 'email', auth_style: 0, + access_policy: '', + access_denied_message: '', }); setSelectedPreset(''); setBaseUrl(''); + resetDiscoveryState(); + setAdvancedActiveKeys([]); setModalVisible(true); }; const handleEdit = (provider) => { setEditingProvider(provider); setFormValues({ ...provider }); - setSelectedPreset(''); - setBaseUrl(''); + setSelectedPreset(OAUTH_PRESETS[provider.slug] ? provider.slug : ''); + setBaseUrl(inferBaseUrlFromProvider(provider)); + resetDiscoveryState(); + setAdvancedActiveKeys([]); setModalVisible(true); }; @@ -189,6 +304,8 @@ const CustomOAuthSetting = ({ serverAddress }) => { }; const handleSubmit = async () => { + const currentValues = getLatestFormValues(); + // Validate required fields const requiredFields = [ 'name', @@ -204,7 +321,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { } for (const field of requiredFields) { - if (!formValues[field]) { + if (!currentValues[field]) { showError(t(`请填写 ${field}`)); return; } @@ -213,11 +330,11 @@ const CustomOAuthSetting = ({ serverAddress }) => { // Validate endpoint URLs must be full URLs const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint']; for (const field of endpointFields) { - const value = formValues[field]; + const value = currentValues[field]; if (value && !value.startsWith('http://') && !value.startsWith('https://')) { - // Check if user selected a preset but forgot to fill server address + // Check if user selected a preset but forgot to fill issuer URL if (selectedPreset && !baseUrl) { - showError(t('请先填写服务器地址,以自动生成完整的端点 URL')); + showError(t('请先填写 Issuer URL,以自动生成完整的端点 URL')); } else { showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)')); } @@ -226,80 +343,199 @@ const CustomOAuthSetting = ({ serverAddress }) => { } try { + const payload = { ...currentValues, enabled: !!currentValues.enabled }; + delete payload.preset; + delete payload.base_url; + let res; if (editingProvider) { res = await API.put( `/api/custom-oauth-provider/${editingProvider.id}`, - formValues + payload ); } else { - res = await API.post('/api/custom-oauth-provider/', formValues); + res = await API.post('/api/custom-oauth-provider/', payload); } if (res.data.success) { showSuccess(editingProvider ? t('更新成功') : t('创建成功')); - setModalVisible(false); + closeModal(); fetchProviders(); } else { showError(res.data.message); } } catch (error) { - showError(editingProvider ? t('更新失败') : t('创建失败')); + showError( + error?.response?.data?.message || + (editingProvider ? t('更新失败') : t('创建失败')), + ); + } + }; + + const handleFetchFromDiscovery = async () => { + const cleanBaseUrl = normalizeBaseUrl(baseUrl); + const configuredWellKnown = (formValues.well_known || '').trim(); + const wellKnownUrl = + configuredWellKnown || + (cleanBaseUrl ? `${cleanBaseUrl}/.well-known/openid-configuration` : ''); + + if (!wellKnownUrl) { + showError(t('请先填写 Discovery URL 或 Issuer URL')); + return; + } + + setDiscoveryLoading(true); + try { + const res = await API.post('/api/custom-oauth-provider/discovery', { + well_known_url: configuredWellKnown || '', + issuer_url: cleanBaseUrl || '', + }); + if (!res.data.success) { + throw new Error(res.data.message || t('未知错误')); + } + const data = res.data.data?.discovery || {}; + const resolvedWellKnown = res.data.data?.well_known_url || wellKnownUrl; + + const discoveredValues = { + well_known: resolvedWellKnown, + }; + const autoFilledFields = []; + if (data.authorization_endpoint) { + discoveredValues.authorization_endpoint = data.authorization_endpoint; + autoFilledFields.push('authorization_endpoint'); + } + if (data.token_endpoint) { + discoveredValues.token_endpoint = data.token_endpoint; + autoFilledFields.push('token_endpoint'); + } + if (data.userinfo_endpoint) { + discoveredValues.user_info_endpoint = data.userinfo_endpoint; + autoFilledFields.push('user_info_endpoint'); + } + + const scopesSupported = Array.isArray(data.scopes_supported) + ? data.scopes_supported + : []; + if (scopesSupported.length > 0 && !formValues.scopes) { + const preferredScopes = ['openid', 'profile', 'email'].filter((scope) => + scopesSupported.includes(scope), + ); + discoveredValues.scopes = + preferredScopes.length > 0 + ? preferredScopes.join(' ') + : scopesSupported.slice(0, 5).join(' '); + autoFilledFields.push('scopes'); + } + + const claimsSupported = Array.isArray(data.claims_supported) + ? data.claims_supported + : []; + const claimMap = { + user_id_field: 'sub', + username_field: 'preferred_username', + display_name_field: 'name', + email_field: 'email', + }; + Object.entries(claimMap).forEach(([field, claim]) => { + if (!formValues[field] && claimsSupported.includes(claim)) { + discoveredValues[field] = claim; + autoFilledFields.push(field); + } + }); + + const hasCoreEndpoint = + discoveredValues.authorization_endpoint || + discoveredValues.token_endpoint || + discoveredValues.user_info_endpoint; + if (!hasCoreEndpoint) { + showError(t('未在 Discovery 响应中找到可用的 OAuth 端点')); + return; + } + + mergeFormValues(discoveredValues); + setDiscoveryInfo({ + wellKnown: wellKnownUrl, + autoFilledFields, + scopesSupported: scopesSupported.slice(0, 12), + claimsSupported: claimsSupported.slice(0, 12), + }); + showSuccess(t('已从 Discovery 自动填充配置')); + } catch (error) { + showError( + t('获取 Discovery 配置失败:') + (error?.message || t('未知错误')), + ); + } finally { + setDiscoveryLoading(false); } }; const handlePresetChange = (preset) => { setSelectedPreset(preset); - if (preset && OAUTH_PRESETS[preset]) { - const presetConfig = OAUTH_PRESETS[preset]; - const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : ''; - const newValues = { - name: presetConfig.name, - slug: preset, - scopes: presetConfig.scopes, - user_id_field: presetConfig.user_id_field, - username_field: presetConfig.username_field, - display_name_field: presetConfig.display_name_field, - email_field: presetConfig.email_field, - auth_style: presetConfig.auth_style ?? 0, - }; - // Only fill endpoints if server address is provided - if (cleanUrl) { - newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint; - newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint; - newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint; - } - setFormValues((prev) => ({ ...prev, ...newValues })); - // Update form fields directly via formApi - if (formApiRef.current) { - Object.entries(newValues).forEach(([key, value]) => { - formApiRef.current.setValue(key, value); - }); - } + resetDiscoveryState(); + const cleanUrl = normalizeBaseUrl(baseUrl); + if (!preset || !OAUTH_PRESETS[preset]) { + mergeFormValues(PRESET_RESET_VALUES); + return; } + + const presetConfig = OAUTH_PRESETS[preset]; + const newValues = { + ...PRESET_RESET_VALUES, + name: presetConfig.name, + slug: preset, + icon: getPresetIcon(preset), + scopes: presetConfig.scopes, + user_id_field: presetConfig.user_id_field, + username_field: presetConfig.username_field, + display_name_field: presetConfig.display_name_field, + email_field: presetConfig.email_field, + auth_style: presetConfig.auth_style ?? 0, + }; + if (cleanUrl) { + newValues.authorization_endpoint = + cleanUrl + presetConfig.authorization_endpoint; + newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint; + newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint; + } + mergeFormValues(newValues); }; const handleBaseUrlChange = (url) => { setBaseUrl(url); if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) { const presetConfig = OAUTH_PRESETS[selectedPreset]; - const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes + const cleanUrl = normalizeBaseUrl(url); const newValues = { authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint, token_endpoint: cleanUrl + presetConfig.token_endpoint, user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint, }; - setFormValues((prev) => ({ ...prev, ...newValues })); - // Update form fields directly via formApi (use merge mode to preserve other fields) - if (formApiRef.current) { - Object.entries(newValues).forEach(([key, value]) => { - formApiRef.current.setValue(key, value); - }); - } + mergeFormValues(newValues); } }; + const applyAccessPolicyTemplate = (templateKey) => { + const template = ACCESS_POLICY_TEMPLATES[templateKey]; + if (!template) return; + mergeFormValues({ access_policy: template }); + showSuccess(t('已填充策略模板')); + }; + + const applyDeniedTemplate = (templateKey) => { + const template = ACCESS_DENIED_TEMPLATES[templateKey]; + if (!template) return; + mergeFormValues({ access_denied_message: template }); + showSuccess(t('已填充提示模板')); + }; + const columns = [ + { + title: t('图标'), + dataIndex: 'icon', + key: 'icon', + width: 80, + render: (icon) => getOAuthProviderIcon(icon || '', 18), + }, { title: t('名称'), dataIndex: 'name', @@ -325,7 +561,10 @@ const CustomOAuthSetting = ({ serverAddress }) => { title: t('Client ID'), dataIndex: 'client_id', key: 'client_id', - render: (id) => (id ? id.substring(0, 20) + '...' : '-'), + render: (id) => { + if (!id) return '-'; + return id.length > 20 ? `${id.substring(0, 20)}...` : id; + }, }, { title: t('操作'), @@ -352,6 +591,10 @@ const CustomOAuthSetting = ({ serverAddress }) => { }, ]; + const discoveryAutoFilledLabels = (discoveryInfo?.autoFilledFields || []) + .map((field) => DISCOVERY_FIELD_LABELS[field] || field) + .join(', '); + return ( @@ -391,56 +634,142 @@ const CustomOAuthSetting = ({ serverAddress }) => { setModalVisible(false)} - okText={t('保存')} - cancelText={t('取消')} - width={800} + onCancel={closeModal} + width={860} + centered + bodyStyle={{ maxHeight: '72vh', overflowY: 'auto', paddingRight: 6 }} + footer={ +
+ + {t('启用供应商')} + mergeFormValues({ enabled: !!checked })} + /> + + {formValues.enabled ? t('已启用') : t('已禁用')} + + + + +
+ } >
setFormValues(values)} + onValueChange={() => { + setFormValues((prev) => ({ ...prev, ...getLatestFormValues() })); + }} getFormApi={(api) => (formApiRef.current = api)} > - {!editingProvider && ( - - - ({ - value: key, - label: config.name, - })), - ]} - /> - - - - - + + {t('Configuration')} + + + {t('先填写配置,再自动填充 OAuth 端点,能显著减少手工输入')} + + {discoveryInfo && ( + +
+ {t('已从 Discovery 获取配置,可继续手动修改所有字段。')} +
+ {discoveryAutoFilledLabels ? ( +
+ {t('自动填充字段')}: + {' '} + {discoveryAutoFilledLabels} +
+ ) : null} + {discoveryInfo.scopesSupported?.length ? ( +
+ {t('Discovery scopes')}: + {' '} + {discoveryInfo.scopesSupported.join(', ')} +
+ ) : null} + {discoveryInfo.claimsSupported?.length ? ( +
+ {t('Discovery claims')}: + {' '} + {discoveryInfo.claimsSupported.join(', ')} +
+ ) : null} +
+ } + /> )} + + + ({ + value: key, + label: config.name, + })), + ]} + /> + + + + + +
+ +
+ +
+ + + + + + { + + + + {t( + '图标使用 react-icons(Simple Icons)或 URL/emoji,例如:github、gitlab、si:google', + )} + + } + showClear + /> + + +
+ {getOAuthProviderIcon(formValues.icon || '', 24)} +
+ +
+ { label={t('Authorization Endpoint')} placeholder={ selectedPreset && OAUTH_PRESETS[selectedPreset] - ? t('填写服务器地址后自动生成:') + + ? t('填写 Issuer URL 后自动生成:') + OAUTH_PRESETS[selectedPreset].authorization_endpoint : 'https://example.com/oauth/authorize' } @@ -544,15 +908,14 @@ const CustomOAuthSetting = ({ serverAddress }) => { - - - @@ -568,7 +931,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { @@ -576,7 +939,7 @@ const CustomOAuthSetting = ({ serverAddress }) => { @@ -586,41 +949,100 @@ const CustomOAuthSetting = ({ serverAddress }) => { - - {t('高级选项')} - + { + const keys = Array.isArray(activeKey) ? activeKey : [activeKey]; + setAdvancedActiveKeys(keys.filter(Boolean)); + }} + > + + + + + + - - - - - - - {t('启用此 OAuth 提供商')} - - - + + {t('准入策略')} + + + {t('可选:基于用户信息 JSON 做组合条件准入,条件不满足时返回自定义提示')} + + + + mergeFormValues({ access_policy: value })} + label={t('准入策略 JSON(可选)')} + rows={6} + placeholder={`{ + "logic": "and", + "conditions": [ + {"field": "trust_level", "op": "gte", "value": 2}, + {"field": "active", "op": "eq", "value": true} + ] +}`} + extraText={t('支持逻辑 and/or 与嵌套 groups;操作符支持 eq/ne/gt/gte/lt/lte/in/not_in/contains/exists')} + showClear + /> + + + + + + + + + mergeFormValues({ access_denied_message: value })} + label={t('拒绝提示模板(可选)')} + placeholder={t('例如:需要等级 {{required}},你当前等级 {{current}}')} + extraText={t('可用变量:{{provider}} {{field}} {{op}} {{required}} {{current}} 以及 {{current.path}}')} + showClear + /> + + + + + + + + diff --git a/web/src/components/settings/personal/cards/AccountManagement.jsx b/web/src/components/settings/personal/cards/AccountManagement.jsx index bc27630ba..29249caa1 100644 --- a/web/src/components/settings/personal/cards/AccountManagement.jsx +++ b/web/src/components/settings/personal/cards/AccountManagement.jsx @@ -50,6 +50,7 @@ import { onLinuxDOOAuthClicked, onDiscordOAuthClicked, onCustomOAuthClicked, + getOAuthProviderIcon, } from '../../../../helpers'; import TwoFASetting from '../components/TwoFASetting'; @@ -148,12 +149,14 @@ const AccountManagement = ({ // Check if custom OAuth provider is bound const isCustomOAuthBound = (providerId) => { - return customOAuthBindings.some((b) => b.provider_id === providerId); + const normalizedId = Number(providerId); + return customOAuthBindings.some((b) => Number(b.provider_id) === normalizedId); }; // Get binding info for a provider const getCustomOAuthBinding = (providerId) => { - return customOAuthBindings.find((b) => b.provider_id === providerId); + const normalizedId = Number(providerId); + return customOAuthBindings.find((b) => Number(b.provider_id) === normalizedId); }; React.useEffect(() => { @@ -524,10 +527,10 @@ const AccountManagement = ({
- + {getOAuthProviderIcon( + provider.icon || binding?.provider_icon || '', + 20, + )}
diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index ecc252cfd..3ba198cb3 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -76,6 +76,31 @@ import { Server, CalendarClock, } from 'lucide-react'; +import { + SiAtlassian, + SiAuth0, + SiAuthentik, + SiBitbucket, + SiDiscord, + SiDropbox, + SiFacebook, + SiGitea, + SiGithub, + SiGitlab, + SiGoogle, + SiKeycloak, + SiLinkedin, + SiNextcloud, + SiNotion, + SiOkta, + SiOpenid, + SiReddit, + SiSlack, + SiTelegram, + SiTwitch, + SiWechat, + SiX, +} from 'react-icons/si'; // 获取侧边栏Lucide图标组件 export function getLucideIcon(key, selected = false) { @@ -472,6 +497,106 @@ export function getLobeHubIcon(iconName, size = 14) { return ; } +const oauthProviderIconMap = { + github: SiGithub, + gitlab: SiGitlab, + gitea: SiGitea, + google: SiGoogle, + discord: SiDiscord, + facebook: SiFacebook, + linkedin: SiLinkedin, + x: SiX, + twitter: SiX, + slack: SiSlack, + telegram: SiTelegram, + wechat: SiWechat, + keycloak: SiKeycloak, + nextcloud: SiNextcloud, + authentik: SiAuthentik, + openid: SiOpenid, + okta: SiOkta, + auth0: SiAuth0, + atlassian: SiAtlassian, + bitbucket: SiBitbucket, + notion: SiNotion, + twitch: SiTwitch, + reddit: SiReddit, + dropbox: SiDropbox, +}; + +function isHttpUrl(value) { + return /^https?:\/\//i.test(value || ''); +} + +function isSimpleEmoji(value) { + if (!value) return false; + const trimmed = String(value).trim(); + return trimmed.length > 0 && trimmed.length <= 4 && !isHttpUrl(trimmed); +} + +function normalizeOAuthIconKey(raw) { + return raw + .trim() + .toLowerCase() + .replace(/^ri:/, '') + .replace(/^react-icons:/, '') + .replace(/^si:/, ''); +} + +/** + * Render custom OAuth provider icon with react-icons or URL/emoji fallback. + * Supported formats: + * - react-icons simple key: github / gitlab / google / keycloak + * - prefixed key: ri:github / si:github + * - full URL image: https://example.com/logo.png + * - emoji: 🐱 + */ +export function getOAuthProviderIcon(iconName, size = 20) { + const raw = String(iconName || '').trim(); + const iconSize = Number(size) > 0 ? Number(size) : 20; + + if (!raw) { + return ; + } + + if (isHttpUrl(raw)) { + return ( + provider icon + ); + } + + if (isSimpleEmoji(raw)) { + return ( + + {raw} + + ); + } + + const key = normalizeOAuthIconKey(raw); + const IconComp = oauthProviderIconMap[key]; + if (IconComp) { + return ; + } + + return {raw.charAt(0).toUpperCase()}; +} + // 颜色列表 const colors = [ 'amber',