mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 05:20:18 +00:00
* feat: improve UI/UX * fix: stabilize provider enabled toggle and polish custom OAuth settings UX * fix: add access policy/message templates and persist advanced fields reliably * fix: move template fill actions below fields and keep advanced form flow cleaner
669 lines
18 KiB
Go
669 lines
18 KiB
Go
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
stdjson "encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/i18n"
|
|
"github.com/QuantumNous/new-api/logger"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/setting/system_setting"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/samber/lo"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// Client-credential delivery modes for the token request. ExchangeToken
// compares the provider's configured AuthStyle against these values.
const (
	// AuthStyleAutoDetect is meant to auto-detect the mode from the server
	// response; ExchangeToken currently falls back to AuthStyleInParams.
	AuthStyleAutoDetect = 0

	// AuthStyleInParams sends client_id and client_secret as POST parameters.
	AuthStyleInParams = 1

	// AuthStyleInHeader sends the credentials as a Basic Auth header.
	AuthStyleInHeader = 2
)
|
|
|
|
// GenericOAuthProvider implements OAuth for custom/generic OAuth providers
|
|
type GenericOAuthProvider struct {
|
|
config *model.CustomOAuthProvider
|
|
}
|
|
|
|
// Access-policy model used to decide whether a user-info response is allowed
// to sign in through a generic provider.
type (
	// accessPolicy is one node of the policy tree: a set of conditions and
	// nested groups combined with Logic ("and" or "or").
	accessPolicy struct {
		Logic      string            `json:"logic"`
		Conditions []accessCondition `json:"conditions"`
		Groups     []accessPolicy    `json:"groups"`
	}

	// accessCondition is a single test: the value found at Field in the
	// user-info JSON is checked against Value using operator Op.
	accessCondition struct {
		Field string `json:"field"`
		Op    string `json:"op"`
		Value any    `json:"value"`
	}

	// accessPolicyFailure records which condition rejected the user, together
	// with the value the policy expected and the value that was found.
	accessPolicyFailure struct {
		Field    string
		Op       string
		Expected any
		Current  any
	}
)
|
|
|
|
// supportedAccessPolicyOps lists every operator a policy condition may use.
// validateAccessCondition rejects any operator that is not in this list.
var supportedAccessPolicyOps = []string{
	// equality
	"eq", "ne",
	// ordering
	"gt", "gte", "lt", "lte",
	// membership of the current value in an expected list
	"in", "not_in",
	// substring / element containment in the current value
	"contains", "not_contains",
	// field presence
	"exists", "not_exists",
}
|
|
|
|
// NewGenericOAuthProvider creates a new generic OAuth provider from config
|
|
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
|
|
return &GenericOAuthProvider{config: config}
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetName() string {
|
|
return p.config.Name
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) IsEnabled() bool {
|
|
return p.config.Enabled
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
|
|
return p.config
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
|
if code == "" {
|
|
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
|
|
|
|
redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
|
|
values := url.Values{}
|
|
values.Set("grant_type", "authorization_code")
|
|
values.Set("code", code)
|
|
values.Set("redirect_uri", redirectUri)
|
|
|
|
// Determine auth style
|
|
authStyle := p.config.AuthStyle
|
|
if authStyle == AuthStyleAutoDetect {
|
|
// Default to params style for most OAuth servers
|
|
authStyle = AuthStyleInParams
|
|
}
|
|
|
|
var req *http.Request
|
|
var err error
|
|
|
|
if authStyle == AuthStyleInParams {
|
|
values.Set("client_id", p.config.ClientId)
|
|
values.Set("client_secret", p.config.ClientSecret)
|
|
}
|
|
|
|
req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
if authStyle == AuthStyleInHeader {
|
|
// Basic Auth
|
|
credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
|
|
req.Header.Set("Authorization", "Basic "+credentials)
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
|
|
p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
|
|
|
|
client := http.Client{
|
|
Timeout: 20 * time.Second,
|
|
}
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
|
|
return nil, err
|
|
}
|
|
|
|
bodyStr := string(body)
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
|
|
|
|
// Try to parse as JSON first
|
|
var tokenResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
TokenType string `json:"token_type"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
Scope string `json:"scope"`
|
|
IDToken string `json:"id_token"`
|
|
Error string `json:"error"`
|
|
ErrorDesc string `json:"error_description"`
|
|
}
|
|
|
|
if err := common.Unmarshal(body, &tokenResponse); err != nil {
|
|
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
|
|
parsedValues, parseErr := url.ParseQuery(bodyStr)
|
|
if parseErr != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
|
|
return nil, err
|
|
}
|
|
tokenResponse.AccessToken = parsedValues.Get("access_token")
|
|
tokenResponse.TokenType = parsedValues.Get("token_type")
|
|
tokenResponse.Scope = parsedValues.Get("scope")
|
|
}
|
|
|
|
if tokenResponse.Error != "" {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
|
|
p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
|
|
}
|
|
|
|
if tokenResponse.AccessToken == "" {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
|
|
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
|
|
|
|
return &OAuthToken{
|
|
AccessToken: tokenResponse.AccessToken,
|
|
TokenType: tokenResponse.TokenType,
|
|
RefreshToken: tokenResponse.RefreshToken,
|
|
ExpiresIn: tokenResponse.ExpiresIn,
|
|
Scope: tokenResponse.Scope,
|
|
IDToken: tokenResponse.IDToken,
|
|
}, nil
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set authorization header
|
|
tokenType := token.TokenType
|
|
if tokenType == "" {
|
|
tokenType = "Bearer"
|
|
}
|
|
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
client := http.Client{
|
|
Timeout: 20 * time.Second,
|
|
}
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
|
|
|
|
if res.StatusCode != http.StatusOK {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
|
|
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
|
}
|
|
|
|
body, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
|
|
return nil, err
|
|
}
|
|
|
|
bodyStr := string(body)
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
|
|
|
|
// Extract fields using gjson (supports JSONPath-like syntax)
|
|
userId := gjson.Get(bodyStr, p.config.UserIdField).String()
|
|
username := gjson.Get(bodyStr, p.config.UsernameField).String()
|
|
displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
|
|
email := gjson.Get(bodyStr, p.config.EmailField).String()
|
|
|
|
// If user ID field returns a number, convert it
|
|
if userId == "" {
|
|
// Try to get as number
|
|
userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
|
|
if userIdNum.Exists() {
|
|
userId = userIdNum.Raw
|
|
// Remove quotes if present
|
|
userId = strings.Trim(userId, "\"")
|
|
}
|
|
}
|
|
|
|
if userId == "" {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
|
|
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
|
|
}
|
|
|
|
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
|
|
p.config.Slug, userId, username, displayName, email)
|
|
|
|
policyRaw := strings.TrimSpace(p.config.AccessPolicy)
|
|
if policyRaw != "" {
|
|
policy, err := parseAccessPolicy(policyRaw)
|
|
if err != nil {
|
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] invalid access policy: %s", p.config.Slug, err.Error()))
|
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, nil, "invalid access policy configuration")
|
|
}
|
|
allowed, failure := evaluateAccessPolicy(bodyStr, policy)
|
|
if !allowed {
|
|
message := renderAccessDeniedMessage(p.config.AccessDeniedMessage, p.config.Name, bodyStr, failure)
|
|
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-Generic-%s] access denied by policy: field=%s op=%s expected=%v current=%v",
|
|
p.config.Slug, failure.Field, failure.Op, failure.Expected, failure.Current))
|
|
return nil, &AccessDeniedError{Message: message}
|
|
}
|
|
}
|
|
|
|
return &OAuthUser{
|
|
ProviderUserID: userId,
|
|
Username: username,
|
|
DisplayName: displayName,
|
|
Email: email,
|
|
Extra: map[string]any{
|
|
"provider": p.config.Slug,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
|
|
return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
|
foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
*user = *foundUser
|
|
return nil
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
|
// For generic providers, we store the binding in user_oauth_bindings table
|
|
// This is handled separately in the OAuth controller
|
|
}
|
|
|
|
func (p *GenericOAuthProvider) GetProviderPrefix() string {
|
|
return p.config.Slug + "_"
|
|
}
|
|
|
|
// GetProviderId returns the provider ID for binding purposes
|
|
func (p *GenericOAuthProvider) GetProviderId() int {
|
|
return p.config.Id
|
|
}
|
|
|
|
// IsGenericProvider returns true for generic providers
|
|
func (p *GenericOAuthProvider) IsGenericProvider() bool {
|
|
return true
|
|
}
|
|
|
|
func parseAccessPolicy(raw string) (*accessPolicy, error) {
|
|
var policy accessPolicy
|
|
if err := common.UnmarshalJsonStr(raw, &policy); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := validateAccessPolicy(&policy); err != nil {
|
|
return nil, err
|
|
}
|
|
return &policy, nil
|
|
}
|
|
|
|
func validateAccessPolicy(policy *accessPolicy) error {
|
|
if policy == nil {
|
|
return errors.New("policy is nil")
|
|
}
|
|
|
|
logic := strings.ToLower(strings.TrimSpace(policy.Logic))
|
|
if logic == "" {
|
|
logic = "and"
|
|
}
|
|
if !lo.Contains([]string{"and", "or"}, logic) {
|
|
return fmt.Errorf("unsupported policy logic: %s", logic)
|
|
}
|
|
policy.Logic = logic
|
|
|
|
if len(policy.Conditions) == 0 && len(policy.Groups) == 0 {
|
|
return errors.New("policy requires at least one condition or group")
|
|
}
|
|
|
|
for index := range policy.Conditions {
|
|
if err := validateAccessCondition(&policy.Conditions[index], index); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
for index := range policy.Groups {
|
|
if err := validateAccessPolicy(&policy.Groups[index]); err != nil {
|
|
return fmt.Errorf("invalid policy group[%d]: %w", index, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateAccessCondition(condition *accessCondition, index int) error {
|
|
if condition == nil {
|
|
return fmt.Errorf("condition[%d] is nil", index)
|
|
}
|
|
|
|
condition.Field = strings.TrimSpace(condition.Field)
|
|
if condition.Field == "" {
|
|
return fmt.Errorf("condition[%d].field is required", index)
|
|
}
|
|
|
|
condition.Op = normalizePolicyOp(condition.Op)
|
|
if !lo.Contains(supportedAccessPolicyOps, condition.Op) {
|
|
return fmt.Errorf("condition[%d].op is unsupported: %s", index, condition.Op)
|
|
}
|
|
|
|
if lo.Contains([]string{"in", "not_in"}, condition.Op) {
|
|
if _, ok := condition.Value.([]any); !ok {
|
|
return fmt.Errorf("condition[%d].value must be an array for op %s", index, condition.Op)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func evaluateAccessPolicy(body string, policy *accessPolicy) (bool, *accessPolicyFailure) {
|
|
if policy == nil {
|
|
return true, nil
|
|
}
|
|
|
|
logic := strings.ToLower(strings.TrimSpace(policy.Logic))
|
|
if logic == "" {
|
|
logic = "and"
|
|
}
|
|
|
|
hasAny := len(policy.Conditions) > 0 || len(policy.Groups) > 0
|
|
if !hasAny {
|
|
return true, nil
|
|
}
|
|
|
|
if logic == "or" {
|
|
var firstFailure *accessPolicyFailure
|
|
for _, cond := range policy.Conditions {
|
|
ok, failure := evaluateAccessCondition(body, cond)
|
|
if ok {
|
|
return true, nil
|
|
}
|
|
if firstFailure == nil {
|
|
firstFailure = failure
|
|
}
|
|
}
|
|
for _, group := range policy.Groups {
|
|
ok, failure := evaluateAccessPolicy(body, &group)
|
|
if ok {
|
|
return true, nil
|
|
}
|
|
if firstFailure == nil {
|
|
firstFailure = failure
|
|
}
|
|
}
|
|
return false, firstFailure
|
|
}
|
|
|
|
for _, cond := range policy.Conditions {
|
|
ok, failure := evaluateAccessCondition(body, cond)
|
|
if !ok {
|
|
return false, failure
|
|
}
|
|
}
|
|
for _, group := range policy.Groups {
|
|
ok, failure := evaluateAccessPolicy(body, &group)
|
|
if !ok {
|
|
return false, failure
|
|
}
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func evaluateAccessCondition(body string, cond accessCondition) (bool, *accessPolicyFailure) {
|
|
path := cond.Field
|
|
op := cond.Op
|
|
result := gjson.Get(body, path)
|
|
current := gjsonResultToValue(result)
|
|
failure := &accessPolicyFailure{
|
|
Field: path,
|
|
Op: op,
|
|
Expected: cond.Value,
|
|
Current: current,
|
|
}
|
|
|
|
switch op {
|
|
case "exists":
|
|
return result.Exists(), failure
|
|
case "not_exists":
|
|
return !result.Exists(), failure
|
|
case "eq":
|
|
return compareAny(current, cond.Value) == 0, failure
|
|
case "ne":
|
|
return compareAny(current, cond.Value) != 0, failure
|
|
case "gt":
|
|
return compareAny(current, cond.Value) > 0, failure
|
|
case "gte":
|
|
return compareAny(current, cond.Value) >= 0, failure
|
|
case "lt":
|
|
return compareAny(current, cond.Value) < 0, failure
|
|
case "lte":
|
|
return compareAny(current, cond.Value) <= 0, failure
|
|
case "in":
|
|
return valueInSlice(current, cond.Value), failure
|
|
case "not_in":
|
|
return !valueInSlice(current, cond.Value), failure
|
|
case "contains":
|
|
return containsValue(current, cond.Value), failure
|
|
case "not_contains":
|
|
return !containsValue(current, cond.Value), failure
|
|
default:
|
|
return false, failure
|
|
}
|
|
}
|
|
|
|
// normalizePolicyOp returns op with surrounding whitespace removed and all
// letters lower-cased, the form in which operators are compared.
func normalizePolicyOp(op string) string {
	trimmed := strings.TrimSpace(op)
	return strings.ToLower(trimmed)
}
|
|
|
|
func gjsonResultToValue(result gjson.Result) any {
|
|
if !result.Exists() {
|
|
return nil
|
|
}
|
|
if result.IsArray() {
|
|
arr := result.Array()
|
|
values := make([]any, 0, len(arr))
|
|
for _, item := range arr {
|
|
values = append(values, gjsonResultToValue(item))
|
|
}
|
|
return values
|
|
}
|
|
switch result.Type {
|
|
case gjson.Null:
|
|
return nil
|
|
case gjson.True:
|
|
return true
|
|
case gjson.False:
|
|
return false
|
|
case gjson.Number:
|
|
return result.Num
|
|
case gjson.String:
|
|
return result.String()
|
|
case gjson.JSON:
|
|
var data any
|
|
if err := common.UnmarshalJsonStr(result.Raw, &data); err == nil {
|
|
return data
|
|
}
|
|
return result.Raw
|
|
default:
|
|
return result.Value()
|
|
}
|
|
}
|
|
|
|
func compareAny(left any, right any) int {
|
|
if lf, ok := toFloat(left); ok {
|
|
if rf, ok2 := toFloat(right); ok2 {
|
|
switch {
|
|
case lf < rf:
|
|
return -1
|
|
case lf > rf:
|
|
return 1
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
}
|
|
|
|
ls := strings.TrimSpace(fmt.Sprint(left))
|
|
rs := strings.TrimSpace(fmt.Sprint(right))
|
|
switch {
|
|
case ls < rs:
|
|
return -1
|
|
case ls > rs:
|
|
return 1
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
|
|
// toFloat converts v to a float64 when it is a Go numeric type, a JSON
// number, or a string holding a number (surrounding whitespace is ignored).
// The boolean result reports whether the conversion succeeded.
//
// Non-finite results are rejected. strconv.ParseFloat accepts the spellings
// "NaN", "Inf" and "Infinity", and compareAny treats two numbers that are
// neither less nor greater than each other as equal, so a NaN operand would
// compare equal to every number and an infinite one would pass any threshold.
// Such inputs are reported as non-numeric instead.
func toFloat(v any) (float64, bool) {
	var (
		number float64
		ok     bool
	)

	switch value := v.(type) {
	case float64:
		number, ok = value, true
	case float32:
		number, ok = float64(value), true
	case int:
		number, ok = float64(value), true
	case int8:
		number, ok = float64(value), true
	case int16:
		number, ok = float64(value), true
	case int32:
		number, ok = float64(value), true
	case int64:
		number, ok = float64(value), true
	case uint:
		number, ok = float64(value), true
	case uint8:
		number, ok = float64(value), true
	case uint16:
		number, ok = float64(value), true
	case uint32:
		number, ok = float64(value), true
	case uint64:
		number, ok = float64(value), true
	case stdjson.Number:
		parsed, err := value.Float64()
		number, ok = parsed, err == nil
	case string:
		parsed, err := strconv.ParseFloat(strings.TrimSpace(value), 64)
		number, ok = parsed, err == nil
	}

	if !ok || !isFiniteFloat(number) {
		return 0, false
	}
	return number, true
}

// isFiniteFloat reports whether f is neither NaN nor an infinity. Both range
// comparisons are false for NaN, so it is rejected along with the infinities.
func isFiniteFloat(f float64) bool {
	const maxFinite = 0x1p1023 * (1 + (1 - 0x1p-52)) // largest finite float64
	return f >= -maxFinite && f <= maxFinite
}
|
|
|
|
func valueInSlice(current any, expected any) bool {
|
|
list, ok := expected.([]any)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return lo.ContainsBy(list, func(item any) bool {
|
|
return compareAny(current, item) == 0
|
|
})
|
|
}
|
|
|
|
func containsValue(current any, expected any) bool {
|
|
switch value := current.(type) {
|
|
case string:
|
|
target := strings.TrimSpace(fmt.Sprint(expected))
|
|
return strings.Contains(value, target)
|
|
case []any:
|
|
return lo.ContainsBy(value, func(item any) bool {
|
|
return compareAny(item, expected) == 0
|
|
})
|
|
}
|
|
return false
|
|
}
|
|
|
|
func renderAccessDeniedMessage(template string, providerName string, body string, failure *accessPolicyFailure) string {
|
|
defaultMessage := "Access denied: your account does not meet this provider's access requirements."
|
|
message := strings.TrimSpace(template)
|
|
if message == "" {
|
|
return defaultMessage
|
|
}
|
|
|
|
if failure == nil {
|
|
failure = &accessPolicyFailure{}
|
|
}
|
|
|
|
replacements := map[string]string{
|
|
"{{provider}}": providerName,
|
|
"{{field}}": failure.Field,
|
|
"{{op}}": failure.Op,
|
|
"{{required}}": fmt.Sprint(failure.Expected),
|
|
"{{current}}": fmt.Sprint(failure.Current),
|
|
}
|
|
|
|
for key, value := range replacements {
|
|
message = strings.ReplaceAll(message, key, value)
|
|
}
|
|
|
|
currentPattern := regexp.MustCompile(`\{\{current\.([^}]+)\}\}`)
|
|
message = currentPattern.ReplaceAllStringFunc(message, func(token string) string {
|
|
match := currentPattern.FindStringSubmatch(token)
|
|
if len(match) != 2 {
|
|
return ""
|
|
}
|
|
path := strings.TrimSpace(match[1])
|
|
if path == "" {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(gjson.Get(body, path).String())
|
|
})
|
|
|
|
requiredPattern := regexp.MustCompile(`\{\{required\.([^}]+)\}\}`)
|
|
message = requiredPattern.ReplaceAllStringFunc(message, func(token string) string {
|
|
match := requiredPattern.FindStringSubmatch(token)
|
|
if len(match) != 2 {
|
|
return ""
|
|
}
|
|
path := strings.TrimSpace(match[1])
|
|
if failure.Field == path {
|
|
return fmt.Sprint(failure.Expected)
|
|
}
|
|
return ""
|
|
})
|
|
|
|
return strings.TrimSpace(message)
|
|
}
|