mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 15:46:44 +00:00
Compare commits
25 Commits
v0.10.8-al
...
v0.10.8-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04dd761880 | ||
|
|
5ff9bc3851 | ||
|
|
053699fa98 | ||
|
|
3e1be18310 | ||
|
|
f3d6e99b28 | ||
|
|
6de8dea9b9 | ||
|
|
ab5456eb10 | ||
|
|
8e6071f146 | ||
|
|
729610beb0 | ||
|
|
c9f5de7048 | ||
|
|
ff71786d8d | ||
|
|
2504818b5a | ||
|
|
9a7a29eed8 | ||
|
|
4d797e0a5b | ||
|
|
3766e3248f | ||
|
|
b55e42eda7 | ||
|
|
af54ea85d2 | ||
|
|
632baadb57 | ||
|
|
df6c669e73 | ||
|
|
c540033985 | ||
|
|
1d611d89d2 | ||
|
|
7b1451caa7 | ||
|
|
ecebd619a4 | ||
|
|
9d73aa44b7 | ||
|
|
05ed9d43af |
38
.gitattributes
vendored
Normal file
38
.gitattributes
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
|
||||
# Go files
|
||||
*.go text eol=lf
|
||||
|
||||
# Config files
|
||||
*.json text eol=lf
|
||||
*.yaml text eol=lf
|
||||
*.yml text eol=lf
|
||||
*.toml text eol=lf
|
||||
*.md text eol=lf
|
||||
|
||||
# JavaScript/TypeScript files
|
||||
*.js text eol=lf
|
||||
*.jsx text eol=lf
|
||||
*.ts text eol=lf
|
||||
*.tsx text eol=lf
|
||||
*.html text eol=lf
|
||||
*.css text eol=lf
|
||||
|
||||
# Shell scripts
|
||||
*.sh text eol=lf
|
||||
|
||||
# Binary files
|
||||
*.png binary
|
||||
*.jpg binary
|
||||
*.jpeg binary
|
||||
*.gif binary
|
||||
*.ico binary
|
||||
*.woff binary
|
||||
*.woff2 binary
|
||||
|
||||
# ============================================
|
||||
# GitHub Linguist - Language Detection
|
||||
# ============================================
|
||||
# Mark web frontend as vendored so GitHub recognizes this as a Go project
|
||||
electron/** linguist-vendored
|
||||
@@ -175,6 +175,10 @@ var (
|
||||
|
||||
DownloadRateLimitNum = 10
|
||||
DownloadRateLimitDuration int64 = 60
|
||||
|
||||
// Per-user search rate limit (applies after authentication, keyed by user ID)
|
||||
SearchRateLimitNum = 10
|
||||
SearchRateLimitDuration int64 = 60
|
||||
)
|
||||
|
||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||
|
||||
@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
|
||||
case int:
|
||||
return fmt.Sprintf("%d", inter.(int))
|
||||
case float64:
|
||||
return fmt.Sprintf("%f", inter.(float64))
|
||||
return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
|
||||
case bool:
|
||||
if inter.(bool) {
|
||||
return "true"
|
||||
|
||||
386
controller/custom_oauth.go
Normal file
386
controller/custom_oauth.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CustomOAuthProviderResponse is the response structure for custom OAuth providers
|
||||
// It excludes sensitive fields like client_secret
|
||||
type CustomOAuthProviderResponse struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
}
|
||||
|
||||
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
|
||||
return &CustomOAuthProviderResponse{
|
||||
Id: p.Id,
|
||||
Name: p.Name,
|
||||
Slug: p.Slug,
|
||||
Enabled: p.Enabled,
|
||||
ClientId: p.ClientId,
|
||||
AuthorizationEndpoint: p.AuthorizationEndpoint,
|
||||
TokenEndpoint: p.TokenEndpoint,
|
||||
UserInfoEndpoint: p.UserInfoEndpoint,
|
||||
Scopes: p.Scopes,
|
||||
UserIdField: p.UserIdField,
|
||||
UsernameField: p.UsernameField,
|
||||
DisplayNameField: p.DisplayNameField,
|
||||
EmailField: p.EmailField,
|
||||
WellKnown: p.WellKnown,
|
||||
AuthStyle: p.AuthStyle,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCustomOAuthProviders returns all custom OAuth providers
|
||||
func GetCustomOAuthProviders(c *gin.Context) {
|
||||
providers, err := model.GetAllCustomOAuthProviders()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]*CustomOAuthProviderResponse, len(providers))
|
||||
for i, p := range providers {
|
||||
response[i] = toCustomOAuthProviderResponse(p)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCustomOAuthProvider returns a single custom OAuth provider by ID
|
||||
func GetCustomOAuthProvider(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := model.GetCustomOAuthProviderById(id)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": toCustomOAuthProviderResponse(provider),
|
||||
})
|
||||
}
|
||||
|
||||
// CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider
|
||||
type CreateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Slug string `json:"slug" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id" binding:"required"`
|
||||
ClientSecret string `json:"client_secret" binding:"required"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
|
||||
TokenEndpoint string `json:"token_endpoint" binding:"required"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
}
|
||||
|
||||
// CreateCustomOAuthProvider creates a new custom OAuth provider
|
||||
func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
var req CreateCustomOAuthProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if slug is already taken
|
||||
if model.IsSlugTaken(req.Slug, 0) {
|
||||
common.ApiErrorMsg(c, "该 Slug 已被使用")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if slug conflicts with built-in providers
|
||||
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
|
||||
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
|
||||
return
|
||||
}
|
||||
|
||||
provider := &model.CustomOAuthProvider{
|
||||
Name: req.Name,
|
||||
Slug: req.Slug,
|
||||
Enabled: req.Enabled,
|
||||
ClientId: req.ClientId,
|
||||
ClientSecret: req.ClientSecret,
|
||||
AuthorizationEndpoint: req.AuthorizationEndpoint,
|
||||
TokenEndpoint: req.TokenEndpoint,
|
||||
UserInfoEndpoint: req.UserInfoEndpoint,
|
||||
Scopes: req.Scopes,
|
||||
UserIdField: req.UserIdField,
|
||||
UsernameField: req.UsernameField,
|
||||
DisplayNameField: req.DisplayNameField,
|
||||
EmailField: req.EmailField,
|
||||
WellKnown: req.WellKnown,
|
||||
AuthStyle: req.AuthStyle,
|
||||
}
|
||||
|
||||
if err := model.CreateCustomOAuthProvider(provider); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Register the provider in the OAuth registry
|
||||
oauth.RegisterOrUpdateCustomProvider(provider)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "创建成功",
|
||||
"data": toCustomOAuthProviderResponse(provider),
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
|
||||
type UpdateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
||||
func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateCustomOAuthProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing provider
|
||||
provider, err := model.GetCustomOAuthProviderById(id)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
||||
return
|
||||
}
|
||||
|
||||
oldSlug := provider.Slug
|
||||
|
||||
// Check if new slug is taken by another provider
|
||||
if req.Slug != "" && req.Slug != provider.Slug {
|
||||
if model.IsSlugTaken(req.Slug, id) {
|
||||
common.ApiErrorMsg(c, "该 Slug 已被使用")
|
||||
return
|
||||
}
|
||||
// Check if slug conflicts with built-in providers
|
||||
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
|
||||
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.Name != "" {
|
||||
provider.Name = req.Name
|
||||
}
|
||||
if req.Slug != "" {
|
||||
provider.Slug = req.Slug
|
||||
}
|
||||
provider.Enabled = req.Enabled
|
||||
if req.ClientId != "" {
|
||||
provider.ClientId = req.ClientId
|
||||
}
|
||||
if req.ClientSecret != "" {
|
||||
provider.ClientSecret = req.ClientSecret
|
||||
}
|
||||
if req.AuthorizationEndpoint != "" {
|
||||
provider.AuthorizationEndpoint = req.AuthorizationEndpoint
|
||||
}
|
||||
if req.TokenEndpoint != "" {
|
||||
provider.TokenEndpoint = req.TokenEndpoint
|
||||
}
|
||||
if req.UserInfoEndpoint != "" {
|
||||
provider.UserInfoEndpoint = req.UserInfoEndpoint
|
||||
}
|
||||
if req.Scopes != "" {
|
||||
provider.Scopes = req.Scopes
|
||||
}
|
||||
if req.UserIdField != "" {
|
||||
provider.UserIdField = req.UserIdField
|
||||
}
|
||||
if req.UsernameField != "" {
|
||||
provider.UsernameField = req.UsernameField
|
||||
}
|
||||
if req.DisplayNameField != "" {
|
||||
provider.DisplayNameField = req.DisplayNameField
|
||||
}
|
||||
if req.EmailField != "" {
|
||||
provider.EmailField = req.EmailField
|
||||
}
|
||||
provider.WellKnown = req.WellKnown
|
||||
provider.AuthStyle = req.AuthStyle
|
||||
|
||||
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the provider in the OAuth registry
|
||||
if oldSlug != provider.Slug {
|
||||
oauth.UnregisterCustomProvider(oldSlug)
|
||||
}
|
||||
oauth.RegisterOrUpdateCustomProvider(provider)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "更新成功",
|
||||
"data": toCustomOAuthProviderResponse(provider),
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteCustomOAuthProvider deletes a custom OAuth provider
|
||||
func DeleteCustomOAuthProvider(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing provider to get slug
|
||||
provider, err := model.GetCustomOAuthProviderById(id)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if there are any user bindings
|
||||
count, _ := model.GetBindingCountByProviderId(id)
|
||||
if count > 0 {
|
||||
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.DeleteCustomOAuthProvider(id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Unregister the provider from the OAuth registry
|
||||
oauth.UnregisterCustomProvider(provider.Slug)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "删除成功",
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserOAuthBindings returns all OAuth bindings for the current user
|
||||
func GetUserOAuthBindings(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
common.ApiErrorMsg(c, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build response with provider info
|
||||
type BindingResponse struct {
|
||||
ProviderId int `json:"provider_id"`
|
||||
ProviderName string `json:"provider_name"`
|
||||
ProviderSlug string `json:"provider_slug"`
|
||||
ProviderUserId string `json:"provider_user_id"`
|
||||
}
|
||||
|
||||
response := make([]BindingResponse, 0)
|
||||
for _, binding := range bindings {
|
||||
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
|
||||
if err != nil {
|
||||
continue // Skip if provider not found
|
||||
}
|
||||
response = append(response, BindingResponse{
|
||||
ProviderId: binding.ProviderId,
|
||||
ProviderName: provider.Name,
|
||||
ProviderSlug: provider.Slug,
|
||||
ProviderUserId: binding.ProviderUserId,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
|
||||
// UnbindCustomOAuth unbinds a custom OAuth provider from the current user
|
||||
func UnbindCustomOAuth(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
common.ApiErrorMsg(c, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
providerIdStr := c.Param("provider_id")
|
||||
providerId, err := strconv.Atoi(providerIdStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的提供商 ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "解绑成功",
|
||||
})
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type GitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse GitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var githubUser GitHubUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if githubUser.Login == "" {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
return &githubUser, nil
|
||||
}
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
// IsGitHubIdAlreadyTaken is unscoped
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
// FillUserByGitHubId is scoped
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
} else {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 GitHub 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := common.GetRandomString(12)
|
||||
affCode := c.Query("aff")
|
||||
if affCode != "" {
|
||||
session.Set("aff", affCode)
|
||||
}
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
@@ -1,268 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LinuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Linux DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to connect to Linux DO server")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
req, err = http.NewRequest("GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get user info from Linux DO")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
|
||||
var linuxdoUser LinuxdoUser
|
||||
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
return nil, errors.New("invalid user info returned")
|
||||
}
|
||||
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxdoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
err := user.FillUserByLinuxDOId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/console_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
@@ -129,6 +130,30 @@ func GetStatus(c *gin.Context) {
|
||||
data["faq"] = console_setting.GetFAQ()
|
||||
}
|
||||
|
||||
// Add enabled custom OAuth providers
|
||||
customProviders := oauth.GetEnabledCustomProviders()
|
||||
if len(customProviders) > 0 {
|
||||
type CustomOAuthInfo struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
}
|
||||
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
|
||||
for _, p := range customProviders {
|
||||
config := p.GetConfig()
|
||||
providersInfo = append(providersInfo, CustomOAuthInfo{
|
||||
Name: config.Name,
|
||||
Slug: config.Slug,
|
||||
ClientId: config.ClientId,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint,
|
||||
Scopes: config.Scopes,
|
||||
})
|
||||
}
|
||||
data["custom_oauth_providers"] = providersInfo
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
312
controller/oauth.go
Normal file
312
controller/oauth.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// providerParams returns map with Provider key for i18n templates
|
||||
func providerParams(name string) map[string]any {
|
||||
return map[string]any{"Provider": name}
|
||||
}
|
||||
|
||||
// GenerateOAuthCode generates a state code for OAuth CSRF protection
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := common.GetRandomString(12)
|
||||
affCode := c.Query("aff")
|
||||
if affCode != "" {
|
||||
session.Set("aff", affCode)
|
||||
}
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleOAuth handles OAuth callback for all standard OAuth providers
|
||||
func HandleOAuth(c *gin.Context) {
|
||||
providerName := c.Param("provider")
|
||||
provider := oauth.GetProvider(providerName)
|
||||
if provider == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
|
||||
// 1. Validate state (CSRF protection)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.T(c, i18n.MsgOAuthStateInvalid),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Check if user is already logged in (bind flow)
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
handleOAuthBind(c, provider)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Check if provider is enabled
|
||||
if !provider.IsEnabled() {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Handle error from provider
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 5. Exchange code for token
|
||||
code := c.Query("code")
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. Get user info
|
||||
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 7. Find or create user
|
||||
user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case *OAuthUserDeletedError:
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
|
||||
case *OAuthRegistrationDisabledError:
|
||||
common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
|
||||
default:
|
||||
common.ApiError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 8. Check user status
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
|
||||
return
|
||||
}
|
||||
|
||||
// 9. Setup login
|
||||
setupLogin(user, c)
|
||||
}
|
||||
|
||||
// handleOAuthBind handles binding OAuth account to existing user
|
||||
func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
if !provider.IsEnabled() {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
code := c.Query("code")
|
||||
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Get user info
|
||||
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
|
||||
if err != nil {
|
||||
handleOAuthError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this OAuth account is already bound (check both new ID and legacy ID)
|
||||
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
// Also check legacy ID to prevent duplicate bindings during migration period
|
||||
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
|
||||
if provider.IsUserIDTaken(legacyID) {
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get current user from session
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user := model.User{Id: id.(int)}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle binding based on provider type
|
||||
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
||||
// Custom provider: use user_oauth_bindings table
|
||||
err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Built-in provider: update user record directly
|
||||
provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
||||
}
|
||||
|
||||
// findOrCreateOAuthUser finds existing user or creates new user
|
||||
func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
|
||||
user := &model.User{}
|
||||
|
||||
// Check if user already exists with new ID
|
||||
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
||||
err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Check if user has been deleted
|
||||
if user.Id == 0 {
|
||||
return nil, &OAuthUserDeletedError{}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Try to find user with legacy ID (for GitHub migration from login to numeric ID)
|
||||
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
|
||||
if provider.IsUserIDTaken(legacyID) {
|
||||
err := provider.FillUserByProviderID(user, legacyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user.Id != 0 {
|
||||
// Found user with legacy ID, migrate to new ID
|
||||
common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
|
||||
user.Id, legacyID, oauthUser.ProviderUserID))
|
||||
if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
|
||||
// Continue with login even if migration fails
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User doesn't exist, create new user if registration is enabled
|
||||
if !common.RegisterEnabled {
|
||||
return nil, &OAuthRegistrationDisabledError{}
|
||||
}
|
||||
|
||||
// Set up new user
|
||||
user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if oauthUser.DisplayName != "" {
|
||||
user.DisplayName = oauthUser.DisplayName
|
||||
} else if oauthUser.Username != "" {
|
||||
user.DisplayName = oauthUser.Username
|
||||
} else {
|
||||
user.DisplayName = provider.GetName() + " User"
|
||||
}
|
||||
if oauthUser.Email != "" {
|
||||
user.Email = oauthUser.Email
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
// Handle affiliate code
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For custom providers, create the binding after user is created
|
||||
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
||||
binding := &model.UserOAuthBinding{
|
||||
UserId: user.Id,
|
||||
ProviderId: genericProvider.GetProviderId(),
|
||||
ProviderUserId: oauthUser.ProviderUserID,
|
||||
}
|
||||
if err := model.CreateUserOAuthBinding(binding); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
|
||||
// Don't fail the registration, just log the error
|
||||
}
|
||||
} else {
|
||||
// Built-in provider: set the provider user ID on the user model
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
if err := user.Update(false); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Error types for OAuth
|
||||
type OAuthUserDeletedError struct{}
|
||||
|
||||
func (e *OAuthUserDeletedError) Error() string {
|
||||
return "user has been deleted"
|
||||
}
|
||||
|
||||
type OAuthRegistrationDisabledError struct{}
|
||||
|
||||
func (e *OAuthRegistrationDisabledError) Error() string {
|
||||
return "registration is disabled"
|
||||
}
|
||||
|
||||
// handleOAuthError handles OAuth errors and returns translated message
|
||||
func handleOAuthError(c *gin.Context, err error) {
|
||||
switch e := err.(type) {
|
||||
case *oauth.OAuthError:
|
||||
if e.Params != nil {
|
||||
common.ApiErrorI18n(c, e.MsgKey, e.Params)
|
||||
} else {
|
||||
common.ApiErrorI18n(c, e.MsgKey)
|
||||
}
|
||||
case *oauth.TrustLevelError:
|
||||
common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
|
||||
default:
|
||||
common.ApiError(c, err)
|
||||
}
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -31,16 +33,17 @@ func SearchTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
keyword := c.Query("keyword")
|
||||
token := c.Query("token")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,6 +160,20 @@ func AddToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// 检查用户令牌数量是否已达上限
|
||||
maxTokens := operation_setting.GetMaxUserTokens()
|
||||
count, err := model.CountUserTokens(c.GetInt("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if int(count) >= maxTokens {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
|
||||
})
|
||||
return
|
||||
}
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)
|
||||
|
||||
28
i18n/keys.go
28
i18n/keys.go
@@ -264,9 +264,20 @@ const (
|
||||
|
||||
// OAuth related messages
|
||||
const (
|
||||
MsgOAuthInvalidCode = "oauth.invalid_code"
|
||||
MsgOAuthGetUserErr = "oauth.get_user_error"
|
||||
MsgOAuthAccountUsed = "oauth.account_used"
|
||||
MsgOAuthInvalidCode = "oauth.invalid_code"
|
||||
MsgOAuthGetUserErr = "oauth.get_user_error"
|
||||
MsgOAuthAccountUsed = "oauth.account_used"
|
||||
MsgOAuthUnknownProvider = "oauth.unknown_provider"
|
||||
MsgOAuthStateInvalid = "oauth.state_invalid"
|
||||
MsgOAuthNotEnabled = "oauth.not_enabled"
|
||||
MsgOAuthUserDeleted = "oauth.user_deleted"
|
||||
MsgOAuthUserBanned = "oauth.user_banned"
|
||||
MsgOAuthBindSuccess = "oauth.bind_success"
|
||||
MsgOAuthAlreadyBound = "oauth.already_bound"
|
||||
MsgOAuthConnectFailed = "oauth.connect_failed"
|
||||
MsgOAuthTokenFailed = "oauth.token_failed"
|
||||
MsgOAuthUserInfoEmpty = "oauth.user_info_empty"
|
||||
MsgOAuthTrustLevelLow = "oauth.trust_level_low"
|
||||
)
|
||||
|
||||
// Model layer error messages (for translation in controller)
|
||||
@@ -276,3 +287,14 @@ const (
|
||||
MsgUuidDuplicate = "common.uuid_duplicate"
|
||||
MsgInvalidInput = "common.invalid_input"
|
||||
)
|
||||
|
||||
// Custom OAuth provider related messages
|
||||
const (
|
||||
MsgCustomOAuthNotFound = "custom_oauth.not_found"
|
||||
MsgCustomOAuthSlugEmpty = "custom_oauth.slug_empty"
|
||||
MsgCustomOAuthSlugExists = "custom_oauth.slug_exists"
|
||||
MsgCustomOAuthNameEmpty = "custom_oauth.name_empty"
|
||||
MsgCustomOAuthHasBindings = "custom_oauth.has_bindings"
|
||||
MsgCustomOAuthBindingNotFound = "custom_oauth.binding_not_found"
|
||||
MsgCustomOAuthProviderIdInvalid = "custom_oauth.provider_id_field_invalid"
|
||||
)
|
||||
|
||||
@@ -223,9 +223,29 @@ ability.repair_running: "A repair task is already running, please try again late
|
||||
oauth.invalid_code: "Invalid authorization code"
|
||||
oauth.get_user_error: "Failed to get user information"
|
||||
oauth.account_used: "This account has been bound to another user"
|
||||
oauth.unknown_provider: "Unknown OAuth provider"
|
||||
oauth.state_invalid: "State parameter is empty or mismatched"
|
||||
oauth.not_enabled: "{{.Provider}} login and registration has not been enabled by administrator"
|
||||
oauth.user_deleted: "User has been deleted"
|
||||
oauth.user_banned: "User has been banned"
|
||||
oauth.bind_success: "Binding successful"
|
||||
oauth.already_bound: "This {{.Provider}} account has already been bound"
|
||||
oauth.connect_failed: "Unable to connect to {{.Provider}} server, please try again later"
|
||||
oauth.token_failed: "Failed to get token from {{.Provider}}, please check settings"
|
||||
oauth.user_info_empty: "{{.Provider}} returned empty user info, please check settings"
|
||||
oauth.trust_level_low: "Linux DO trust level does not meet the minimum required by administrator"
|
||||
|
||||
# Model layer error messages
|
||||
redeem.failed: "Redemption failed, please try again later"
|
||||
user.create_default_token_error: "Failed to create default token"
|
||||
common.uuid_duplicate: "Please retry, the system generated a duplicate UUID!"
|
||||
common.invalid_input: "Invalid input"
|
||||
|
||||
# Custom OAuth provider messages
|
||||
custom_oauth.not_found: "Custom OAuth provider not found"
|
||||
custom_oauth.slug_empty: "Slug cannot be empty"
|
||||
custom_oauth.slug_exists: "Slug already exists"
|
||||
custom_oauth.name_empty: "Provider name cannot be empty"
|
||||
custom_oauth.has_bindings: "Cannot delete provider with existing user bindings"
|
||||
custom_oauth.binding_not_found: "OAuth binding not found"
|
||||
custom_oauth.provider_id_field_invalid: "Could not extract user ID from provider response"
|
||||
|
||||
@@ -224,9 +224,29 @@ ability.repair_running: "已经有一个修复任务在运行中,请稍后再
|
||||
oauth.invalid_code: "无效的授权码"
|
||||
oauth.get_user_error: "获取用户信息失败"
|
||||
oauth.account_used: "该账户已被其他用户绑定"
|
||||
oauth.unknown_provider: "未知的 OAuth 提供商"
|
||||
oauth.state_invalid: "state 参数为空或不匹配"
|
||||
oauth.not_enabled: "管理员未开启通过 {{.Provider}} 登录以及注册"
|
||||
oauth.user_deleted: "用户已注销"
|
||||
oauth.user_banned: "用户已被封禁"
|
||||
oauth.bind_success: "绑定成功"
|
||||
oauth.already_bound: "该 {{.Provider}} 账户已被绑定"
|
||||
oauth.connect_failed: "无法连接至 {{.Provider}} 服务器,请稍后重试"
|
||||
oauth.token_failed: "{{.Provider}} 获取 Token 失败,请检查设置"
|
||||
oauth.user_info_empty: "{{.Provider}} 获取用户信息为空,请检查设置"
|
||||
oauth.trust_level_low: "Linux DO 信任等级未达到管理员设置的最低信任等级"
|
||||
|
||||
# Model layer error messages
|
||||
redeem.failed: "兑换失败,请稍后重试"
|
||||
user.create_default_token_error: "创建默认令牌失败"
|
||||
common.uuid_duplicate: "请重试,系统生成的 UUID 竟然重复了!"
|
||||
common.invalid_input: "输入不合法"
|
||||
|
||||
# Custom OAuth provider messages
|
||||
custom_oauth.not_found: "自定义 OAuth 提供商不存在"
|
||||
custom_oauth.slug_empty: "标识符不能为空"
|
||||
custom_oauth.slug_exists: "标识符已存在"
|
||||
custom_oauth.name_empty: "提供商名称不能为空"
|
||||
custom_oauth.has_bindings: "无法删除已有用户绑定的提供商"
|
||||
custom_oauth.binding_not_found: "OAuth 绑定不存在"
|
||||
custom_oauth.provider_id_field_invalid: "无法从提供商响应中提取用户 ID"
|
||||
|
||||
8
main.go
8
main.go
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/QuantumNous/new-api/router"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
_ "github.com/QuantumNous/new-api/setting/performance_setting"
|
||||
@@ -291,5 +292,12 @@ func InitResources() error {
|
||||
// Register user language loader for lazy loading
|
||||
i18n.SetUserLangLoader(model.GetUserLanguage)
|
||||
|
||||
// Load custom OAuth providers from database
|
||||
err = oauth.LoadCustomProviders()
|
||||
if err != nil {
|
||||
common.SysError("failed to load custom OAuth providers: " + err.Error())
|
||||
// Don't return error, custom OAuth is not critical
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
|
||||
func UploadRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
|
||||
}
|
||||
|
||||
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
|
||||
// instead of client IP, making it resistant to proxy rotation attacks.
|
||||
// Must be used AFTER authentication middleware (UserAuth).
|
||||
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
||||
if common.RedisEnabled {
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
|
||||
userRedisRateLimiter(c, maxRequestNum, duration, key)
|
||||
}
|
||||
}
|
||||
// It's safe to call multi times.
|
||||
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("%s:user:%d", mark, userId)
|
||||
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
|
||||
// (to support user-ID-based keys).
|
||||
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
|
||||
ctx := context.Background()
|
||||
rdb := common.RDB
|
||||
listLength, err := rdb.LLen(ctx, key).Result()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if listLength < int64(maxRequestNum) {
|
||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
} else {
|
||||
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
||||
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
nowTimeStr := time.Now().Format(timeFormat)
|
||||
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
} else {
|
||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SearchRateLimit returns a per-user rate limiter for search endpoints.
|
||||
// 10 requests per 60 seconds per user (by user ID, not IP).
|
||||
func SearchRateLimit() func(c *gin.Context) {
|
||||
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
|
||||
}
|
||||
|
||||
158
model/custom_oauth_provider.go
Normal file
158
model/custom_oauth_provider.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CustomOAuthProvider stores configuration for custom OAuth providers
|
||||
type CustomOAuthProvider struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
|
||||
Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
|
||||
Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
|
||||
ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
|
||||
ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
|
||||
TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
|
||||
UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
|
||||
Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
|
||||
|
||||
// Field mapping configuration (supports JSONPath via gjson)
|
||||
UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
|
||||
UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
|
||||
DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
|
||||
EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
|
||||
|
||||
// Advanced options
|
||||
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
|
||||
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (CustomOAuthProvider) TableName() string {
|
||||
return "custom_oauth_providers"
|
||||
}
|
||||
|
||||
// GetAllCustomOAuthProviders returns all custom OAuth providers
|
||||
func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
|
||||
var providers []*CustomOAuthProvider
|
||||
err := DB.Order("id asc").Find(&providers).Error
|
||||
return providers, err
|
||||
}
|
||||
|
||||
// GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
|
||||
func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
|
||||
var providers []*CustomOAuthProvider
|
||||
err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
|
||||
return providers, err
|
||||
}
|
||||
|
||||
// GetCustomOAuthProviderById returns a custom OAuth provider by ID
|
||||
func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
|
||||
var provider CustomOAuthProvider
|
||||
err := DB.First(&provider, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
|
||||
func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
|
||||
var provider CustomOAuthProvider
|
||||
err := DB.Where("slug = ?", slug).First(&provider).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// CreateCustomOAuthProvider creates a new custom OAuth provider
|
||||
func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
||||
if err := validateCustomOAuthProvider(provider); err != nil {
|
||||
return err
|
||||
}
|
||||
return DB.Create(provider).Error
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
||||
func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
||||
if err := validateCustomOAuthProvider(provider); err != nil {
|
||||
return err
|
||||
}
|
||||
return DB.Save(provider).Error
|
||||
}
|
||||
|
||||
// DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
|
||||
func DeleteCustomOAuthProvider(id int) error {
|
||||
// First, delete all user bindings for this provider
|
||||
if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return DB.Delete(&CustomOAuthProvider{}, id).Error
|
||||
}
|
||||
|
||||
// IsSlugTaken checks if a slug is already taken by another provider
|
||||
func IsSlugTaken(slug string, excludeId int) bool {
|
||||
var count int64
|
||||
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
|
||||
if excludeId > 0 {
|
||||
query = query.Where("id != ?", excludeId)
|
||||
}
|
||||
query.Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// validateCustomOAuthProvider validates a custom OAuth provider configuration
|
||||
func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
|
||||
if provider.Name == "" {
|
||||
return errors.New("provider name is required")
|
||||
}
|
||||
if provider.Slug == "" {
|
||||
return errors.New("provider slug is required")
|
||||
}
|
||||
// Slug must be lowercase and contain only alphanumeric characters and hyphens
|
||||
slug := strings.ToLower(provider.Slug)
|
||||
for _, c := range slug {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
|
||||
return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
|
||||
}
|
||||
}
|
||||
provider.Slug = slug
|
||||
|
||||
if provider.ClientId == "" {
|
||||
return errors.New("client ID is required")
|
||||
}
|
||||
if provider.AuthorizationEndpoint == "" {
|
||||
return errors.New("authorization endpoint is required")
|
||||
}
|
||||
if provider.TokenEndpoint == "" {
|
||||
return errors.New("token endpoint is required")
|
||||
}
|
||||
if provider.UserInfoEndpoint == "" {
|
||||
return errors.New("user info endpoint is required")
|
||||
}
|
||||
|
||||
// Set defaults for field mappings if empty
|
||||
if provider.UserIdField == "" {
|
||||
provider.UserIdField = "sub"
|
||||
}
|
||||
if provider.UsernameField == "" {
|
||||
provider.UsernameField = "preferred_username"
|
||||
}
|
||||
if provider.DisplayNameField == "" {
|
||||
provider.DisplayNameField = "name"
|
||||
}
|
||||
if provider.EmailField == "" {
|
||||
provider.EmailField = "email"
|
||||
}
|
||||
if provider.Scopes == "" {
|
||||
provider.Scopes = "openid profile email"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -274,6 +274,8 @@ func migrateDB() error {
|
||||
&SubscriptionOrder{},
|
||||
&UserSubscription{},
|
||||
&SubscriptionPreConsumeRecord{},
|
||||
&CustomOAuthProvider{},
|
||||
&UserOAuthBinding{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -320,6 +322,8 @@ func migrateDBFast() error {
|
||||
{&SubscriptionOrder{}, "SubscriptionOrder"},
|
||||
{&UserSubscription{}, "UserSubscription"},
|
||||
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
|
||||
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
|
||||
{&UserOAuthBinding{}, "UserOAuthBinding"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -63,12 +64,104 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
|
||||
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
|
||||
// 规则:
|
||||
// 1. 转义 ! 和 _(使用 ! 作为 ESCAPE 字符,兼容 MySQL/PostgreSQL/SQLite)
|
||||
// 2. 连续的 % 合并为单个 %
|
||||
// 3. 最多允许 2 个 %
|
||||
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
|
||||
// 5. 不含 % 时按精确匹配
|
||||
func sanitizeLikePattern(input string) (string, error) {
|
||||
// 1. 先转义 ESCAPE 字符 ! 自身,再转义 _
|
||||
// 使用 ! 而非 \ 作为 ESCAPE 字符,避免 MySQL 中反斜杠的字符串转义问题
|
||||
input = strings.ReplaceAll(input, "!", "!!")
|
||||
input = strings.ReplaceAll(input, `_`, `!_`)
|
||||
|
||||
// 2. 连续的 % 直接拒绝
|
||||
if strings.Contains(input, "%%") {
|
||||
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
|
||||
}
|
||||
|
||||
// 3. 统计 % 数量,不得超过 2
|
||||
count := strings.Count(input, "%")
|
||||
if count > 2 {
|
||||
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
|
||||
}
|
||||
|
||||
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
|
||||
if count > 0 {
|
||||
stripped := strings.ReplaceAll(input, "%", "")
|
||||
if len(stripped) < 2 {
|
||||
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
|
||||
}
|
||||
return input, nil
|
||||
}
|
||||
|
||||
// 5. 无 % 时,精确全匹配
|
||||
return input, nil
|
||||
}
|
||||
|
||||
const searchHardLimit = 100
|
||||
|
||||
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
|
||||
// model 层强制截断
|
||||
if limit <= 0 || limit > searchHardLimit {
|
||||
limit = searchHardLimit
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
token = strings.Trim(token, "sk-")
|
||||
}
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||
return tokens, err
|
||||
|
||||
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
|
||||
maxTokens := operation_setting.GetMaxUserTokens()
|
||||
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
|
||||
if hasFuzzy {
|
||||
count, err := CountUserTokens(userId)
|
||||
if err != nil {
|
||||
common.SysLog("failed to count user tokens: " + err.Error())
|
||||
return nil, 0, errors.New("获取令牌数量失败")
|
||||
}
|
||||
if int(count) > maxTokens {
|
||||
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
|
||||
}
|
||||
}
|
||||
|
||||
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
|
||||
|
||||
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
|
||||
if keyword != "" {
|
||||
keywordPattern, err := sanitizeLikePattern(keyword)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '!'", keywordPattern)
|
||||
}
|
||||
if token != "" {
|
||||
tokenPattern, err := sanitizeLikePattern(token)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '!'", tokenPattern)
|
||||
}
|
||||
|
||||
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
|
||||
err = baseQuery.Limit(maxTokens).Count(&total).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to count search tokens: " + err.Error())
|
||||
return nil, 0, errors.New("搜索令牌失败")
|
||||
}
|
||||
|
||||
// 再分页查数据
|
||||
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to search tokens: " + err.Error())
|
||||
return nil, 0, errors.New("搜索令牌失败")
|
||||
}
|
||||
return tokens, total, nil
|
||||
}
|
||||
|
||||
func ValidateUserToken(key string) (token *Token, err error) {
|
||||
|
||||
@@ -540,6 +540,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGitHubId updates the user's GitHub ID (used for migration from login to numeric ID)
|
||||
func (user *User) UpdateGitHubId(newGitHubId string) error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("user id is empty")
|
||||
}
|
||||
return DB.Model(user).Update("github_id", newGitHubId).Error
|
||||
}
|
||||
|
||||
func (user *User) FillUserByDiscordId() error {
|
||||
if user.DiscordId == "" {
|
||||
return errors.New("discord id 为空!")
|
||||
|
||||
125
model/user_oauth_binding.go
Normal file
125
model/user_oauth_binding.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
|
||||
type UserOAuthBinding struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"index;not null"` // User ID
|
||||
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
|
||||
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Composite unique index to prevent duplicate bindings
|
||||
// One OAuth account can only be bound to one user
|
||||
}
|
||||
|
||||
func (UserOAuthBinding) TableName() string {
|
||||
return "user_oauth_bindings"
|
||||
}
|
||||
|
||||
// GetUserOAuthBindingsByUserId returns all OAuth bindings for a user
|
||||
func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
|
||||
var bindings []*UserOAuthBinding
|
||||
err := DB.Where("user_id = ?", userId).Find(&bindings).Error
|
||||
return bindings, err
|
||||
}
|
||||
|
||||
// GetUserOAuthBinding returns a specific binding for a user and provider
|
||||
func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
|
||||
var binding UserOAuthBinding
|
||||
err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &binding, nil
|
||||
}
|
||||
|
||||
// GetUserByOAuthBinding finds a user by provider ID and provider user ID
|
||||
func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
|
||||
var binding UserOAuthBinding
|
||||
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user User
|
||||
err = DB.First(&user, binding.UserId).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// IsProviderUserIdTaken checks if a provider user ID is already bound to any user
|
||||
func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
|
||||
var count int64
|
||||
DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// CreateUserOAuthBinding creates a new OAuth binding
|
||||
func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
|
||||
if binding.UserId == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
if binding.ProviderId == 0 {
|
||||
return errors.New("provider ID is required")
|
||||
}
|
||||
if binding.ProviderUserId == "" {
|
||||
return errors.New("provider user ID is required")
|
||||
}
|
||||
|
||||
// Check if this provider user ID is already taken
|
||||
if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) {
|
||||
return errors.New("this OAuth account is already bound to another user")
|
||||
}
|
||||
|
||||
binding.CreatedAt = time.Now()
|
||||
return DB.Create(binding).Error
|
||||
}
|
||||
|
||||
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
|
||||
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
|
||||
// Check if the new provider user ID is already taken by another user
|
||||
var existingBinding UserOAuthBinding
|
||||
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error
|
||||
if err == nil && existingBinding.UserId != userId {
|
||||
return errors.New("this OAuth account is already bound to another user")
|
||||
}
|
||||
|
||||
// Check if user already has a binding for this provider
|
||||
var binding UserOAuthBinding
|
||||
err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
|
||||
if err != nil {
|
||||
// No existing binding, create new one
|
||||
return CreateUserOAuthBinding(&UserOAuthBinding{
|
||||
UserId: userId,
|
||||
ProviderId: providerId,
|
||||
ProviderUserId: newProviderUserId,
|
||||
})
|
||||
}
|
||||
|
||||
// Update existing binding
|
||||
return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
|
||||
}
|
||||
|
||||
// DeleteUserOAuthBinding deletes an OAuth binding
|
||||
func DeleteUserOAuthBinding(userId, providerId int) error {
|
||||
return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error
|
||||
}
|
||||
|
||||
// DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user
|
||||
func DeleteUserOAuthBindingsByUserId(userId int) error {
|
||||
return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error
|
||||
}
|
||||
|
||||
// GetBindingCountByProviderId returns the number of bindings for a provider
|
||||
func GetBindingCountByProviderId(providerId int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
172
oauth/discord.go
Normal file
172
oauth/discord.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("discord", &DiscordProvider{})
|
||||
}
|
||||
|
||||
// DiscordProvider implements OAuth for Discord
|
||||
type DiscordProvider struct{}
|
||||
|
||||
type discordOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type discordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) GetName() string {
|
||||
return "Discord"
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) IsEnabled() bool {
|
||||
return system_setting.GetDiscordSettings().Enabled
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
settings := system_setting.GetDiscordSettings()
|
||||
redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)
|
||||
values := url.Values{}
|
||||
values.Set("client_id", settings.ClientId)
|
||||
values.Set("client_secret", settings.ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", redirectUri)
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var discordResponse discordOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: discordResponse.AccessToken,
|
||||
TokenType: discordResponse.TokenType,
|
||||
RefreshToken: discordResponse.RefreshToken,
|
||||
ExpiresIn: discordResponse.ExpiresIn,
|
||||
Scope: discordResponse.Scope,
|
||||
IDToken: discordResponse.IDToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
||||
}
|
||||
|
||||
var discordUser discordUser
|
||||
err = json.NewDecoder(res.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: discordUser.UID,
|
||||
Username: discordUser.ID,
|
||||
DisplayName: discordUser.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsDiscordIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.DiscordId = providerUserID
|
||||
return user.FillUserByDiscordId()
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.DiscordId = providerUserID
|
||||
}
|
||||
|
||||
func (p *DiscordProvider) GetProviderPrefix() string {
|
||||
return "discord_"
|
||||
}
|
||||
268
oauth/generic.go
Normal file
268
oauth/generic.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// AuthStyle defines how to send client credentials
|
||||
const (
|
||||
AuthStyleAutoDetect = 0 // Auto-detect based on server response
|
||||
AuthStyleInParams = 1 // Send client_id and client_secret as POST parameters
|
||||
AuthStyleInHeader = 2 // Send as Basic Auth header
|
||||
)
|
||||
|
||||
// GenericOAuthProvider implements OAuth for custom/generic OAuth providers
|
||||
type GenericOAuthProvider struct {
|
||||
config *model.CustomOAuthProvider
|
||||
}
|
||||
|
||||
// NewGenericOAuthProvider creates a new generic OAuth provider from config
|
||||
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
|
||||
return &GenericOAuthProvider{config: config}
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) GetName() string {
|
||||
return p.config.Name
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) IsEnabled() bool {
|
||||
return p.config.Enabled
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
|
||||
return p.config
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
|
||||
|
||||
redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
|
||||
values := url.Values{}
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("code", code)
|
||||
values.Set("redirect_uri", redirectUri)
|
||||
|
||||
// Determine auth style
|
||||
authStyle := p.config.AuthStyle
|
||||
if authStyle == AuthStyleAutoDetect {
|
||||
// Default to params style for most OAuth servers
|
||||
authStyle = AuthStyleInParams
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
if authStyle == AuthStyleInParams {
|
||||
values.Set("client_id", p.config.ClientId)
|
||||
values.Set("client_secret", p.config.ClientSecret)
|
||||
}
|
||||
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
if authStyle == AuthStyleInHeader {
|
||||
// Basic Auth
|
||||
credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
|
||||
req.Header.Set("Authorization", "Basic "+credentials)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
|
||||
p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
|
||||
|
||||
// Try to parse as JSON first
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
IDToken string `json:"id_token"`
|
||||
Error string `json:"error"`
|
||||
ErrorDesc string `json:"error_description"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &tokenResponse); err != nil {
|
||||
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
|
||||
parsedValues, parseErr := url.ParseQuery(bodyStr)
|
||||
if parseErr != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
tokenResponse.AccessToken = parsedValues.Get("access_token")
|
||||
tokenResponse.TokenType = parsedValues.Get("token_type")
|
||||
tokenResponse.Scope = parsedValues.Get("scope")
|
||||
}
|
||||
|
||||
if tokenResponse.Error != "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
|
||||
p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
|
||||
}
|
||||
|
||||
if tokenResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
Scope: tokenResponse.Scope,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set authorization header
|
||||
tokenType := token.TokenType
|
||||
if tokenType == "" {
|
||||
tokenType = "Bearer"
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyStr := string(body)
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
|
||||
|
||||
// Extract fields using gjson (supports JSONPath-like syntax)
|
||||
userId := gjson.Get(bodyStr, p.config.UserIdField).String()
|
||||
username := gjson.Get(bodyStr, p.config.UsernameField).String()
|
||||
displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
|
||||
email := gjson.Get(bodyStr, p.config.EmailField).String()
|
||||
|
||||
// If user ID field returns a number, convert it
|
||||
if userId == "" {
|
||||
// Try to get as number
|
||||
userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
|
||||
if userIdNum.Exists() {
|
||||
userId = userIdNum.Raw
|
||||
// Remove quotes if present
|
||||
userId = strings.Trim(userId, "\"")
|
||||
}
|
||||
}
|
||||
|
||||
if userId == "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
|
||||
p.config.Slug, userId, username, displayName, email)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: userId,
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
Email: email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*user = *foundUser
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
// For generic providers, we store the binding in user_oauth_bindings table
|
||||
// This is handled separately in the OAuth controller
|
||||
}
|
||||
|
||||
func (p *GenericOAuthProvider) GetProviderPrefix() string {
|
||||
return p.config.Slug + "_"
|
||||
}
|
||||
|
||||
// GetProviderId returns the provider ID for binding purposes
|
||||
func (p *GenericOAuthProvider) GetProviderId() int {
|
||||
return p.config.Id
|
||||
}
|
||||
|
||||
// IsGenericProvider returns true for generic providers
|
||||
func (p *GenericOAuthProvider) IsGenericProvider() bool {
|
||||
return true
|
||||
}
|
||||
166
oauth/github.go
Normal file
166
oauth/github.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("github", &GitHubProvider{})
|
||||
}
|
||||
|
||||
// GitHubProvider implements OAuth for GitHub
|
||||
type GitHubProvider struct{}
|
||||
|
||||
type gitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type gitHubUser struct {
|
||||
Id int64 `json:"id"` // GitHub numeric ID (permanent, never changes)
|
||||
Login string `json:"login"` // GitHub username (can be changed by user)
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetName() string {
|
||||
return "GitHub"
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) IsEnabled() bool {
|
||||
return common.GitHubOAuthEnabled
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
values := map[string]string{
|
||||
"client_id": common.GitHubClientId,
|
||||
"client_secret": common.GitHubClientSecret,
|
||||
"code": code,
|
||||
}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var oAuthResponse gitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oAuthResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, "[OAuth-GitHub] ExchangeToken failed: empty access token")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "GitHub"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken success: scope=%s", oAuthResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: oAuthResponse.AccessToken,
|
||||
TokenType: oAuthResponse.TokenType,
|
||||
Scope: oAuthResponse.Scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo: fetching user info")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
var githubUser gitHubUser
|
||||
err = json.NewDecoder(res.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if githubUser.Id == 0 || githubUser.Login == "" {
|
||||
logger.LogError(ctx, "[OAuth-GitHub] GetUserInfo failed: empty id or login field")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "GitHub"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo success: id=%d, login=%s, name=%s, email=%s",
|
||||
githubUser.Id, githubUser.Login, githubUser.Name, githubUser.Email)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: strconv.FormatInt(githubUser.Id, 10), // Use numeric ID as primary identifier
|
||||
Username: githubUser.Login,
|
||||
DisplayName: githubUser.Name,
|
||||
Email: githubUser.Email,
|
||||
Extra: map[string]any{
|
||||
"legacy_id": githubUser.Login, // Store login for migration from old accounts
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsGitHubIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.GitHubId = providerUserID
|
||||
return user.FillUserByGitHubId()
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.GitHubId = providerUserID
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetProviderPrefix() string {
|
||||
return "github_"
|
||||
}
|
||||
195
oauth/linuxdo.go
Normal file
195
oauth/linuxdo.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("linuxdo", &LinuxDOProvider{})
|
||||
}
|
||||
|
||||
// LinuxDOProvider implements OAuth for Linux DO
|
||||
type LinuxDOProvider struct{}
|
||||
|
||||
type linuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) GetName() string {
|
||||
return "Linux DO"
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) IsEnabled() bool {
|
||||
return common.LinuxDOOAuthEnabled
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: token_endpoint=%s, redirect_uri=%s", tokenEndpoint, redirectURI)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken failed: %s", tokenRes.Message))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Linux DO"}, tokenRes.Message)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken success")
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: user_endpoint=%s", userEndpoint)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
var linuxdoUser linuxdoUser
|
||||
if err := json.NewDecoder(res.Body).Decode(&linuxdoUser); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
logger.LogError(ctx, "[OAuth-LinuxDO] GetUserInfo failed: invalid user id")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Linux DO"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: id=%d, username=%s, name=%s, trust_level=%d, active=%v, silenced=%v",
|
||||
linuxdoUser.Id, linuxdoUser.Username, linuxdoUser.Name, linuxdoUser.TrustLevel, linuxdoUser.Active, linuxdoUser.Silenced)
|
||||
|
||||
// Check trust level
|
||||
if linuxdoUser.TrustLevel < common.LinuxDOMinimumTrustLevel {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo: trust level too low (required=%d, current=%d)",
|
||||
common.LinuxDOMinimumTrustLevel, linuxdoUser.TrustLevel))
|
||||
return nil, &TrustLevelError{
|
||||
Required: common.LinuxDOMinimumTrustLevel,
|
||||
Current: linuxdoUser.TrustLevel,
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo success: id=%d, username=%s", linuxdoUser.Id, linuxdoUser.Username)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: strconv.Itoa(linuxdoUser.Id),
|
||||
Username: linuxdoUser.Username,
|
||||
DisplayName: linuxdoUser.Name,
|
||||
Extra: map[string]any{
|
||||
"trust_level": linuxdoUser.TrustLevel,
|
||||
"active": linuxdoUser.Active,
|
||||
"silenced": linuxdoUser.Silenced,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsLinuxDOIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.LinuxDOId = providerUserID
|
||||
return user.FillUserByLinuxDOId()
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.LinuxDOId = providerUserID
|
||||
}
|
||||
|
||||
func (p *LinuxDOProvider) GetProviderPrefix() string {
|
||||
return "linuxdo_"
|
||||
}
|
||||
|
||||
// TrustLevelError indicates the user's trust level is too low
|
||||
type TrustLevelError struct {
|
||||
Required int
|
||||
Current int
|
||||
}
|
||||
|
||||
func (e *TrustLevelError) Error() string {
|
||||
return "trust level too low"
|
||||
}
|
||||
177
oauth/oidc.go
Normal file
177
oauth/oidc.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Register("oidc", &OIDCProvider{})
|
||||
}
|
||||
|
||||
// OIDCProvider implements OAuth for OIDC
|
||||
type OIDCProvider struct{}
|
||||
|
||||
type oidcOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type oidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetName() string {
|
||||
return "OIDC"
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) IsEnabled() bool {
|
||||
return system_setting.GetOIDCSettings().Enabled
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
|
||||
if code == "" {
|
||||
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)])
|
||||
|
||||
settings := system_setting.GetOIDCSettings()
|
||||
redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)
|
||||
values := url.Values{}
|
||||
values.Set("client_id", settings.ClientId)
|
||||
values.Set("client_secret", settings.ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", redirectUri)
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode)
|
||||
|
||||
var oidcResponse oidcOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token")
|
||||
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope)
|
||||
|
||||
return &OAuthToken{
|
||||
AccessToken: oidcResponse.AccessToken,
|
||||
TokenType: oidcResponse.TokenType,
|
||||
RefreshToken: oidcResponse.RefreshToken,
|
||||
ExpiresIn: oidcResponse.ExpiresIn,
|
||||
Scope: oidcResponse.Scope,
|
||||
IDToken: oidcResponse.IDToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
|
||||
settings := system_setting.GetOIDCSettings()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error()))
|
||||
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode)
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
|
||||
}
|
||||
|
||||
var oidcUser oidcUser
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email))
|
||||
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"})
|
||||
}
|
||||
|
||||
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email)
|
||||
|
||||
return &OAuthUser{
|
||||
ProviderUserID: oidcUser.OpenID,
|
||||
Username: oidcUser.PreferredUsername,
|
||||
DisplayName: oidcUser.Name,
|
||||
Email: oidcUser.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool {
|
||||
return model.IsOidcIdAlreadyTaken(providerUserID)
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
|
||||
user.OidcId = providerUserID
|
||||
return user.FillUserByOidcId()
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) {
|
||||
user.OidcId = providerUserID
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) GetProviderPrefix() string {
|
||||
return "oidc_"
|
||||
}
|
||||
36
oauth/provider.go
Normal file
36
oauth/provider.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Provider defines the interface for OAuth providers
|
||||
type Provider interface {
|
||||
// GetName returns the display name of the provider (e.g., "GitHub", "Discord")
|
||||
GetName() string
|
||||
|
||||
// IsEnabled returns whether this OAuth provider is enabled
|
||||
IsEnabled() bool
|
||||
|
||||
// ExchangeToken exchanges the authorization code for an access token
|
||||
// The gin.Context is passed for providers that need request info (e.g., for redirect_uri)
|
||||
ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error)
|
||||
|
||||
// GetUserInfo retrieves user information using the access token
|
||||
GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error)
|
||||
|
||||
// IsUserIDTaken checks if the provider user ID is already associated with an account
|
||||
IsUserIDTaken(providerUserID string) bool
|
||||
|
||||
// FillUserByProviderID fills the user model by provider user ID
|
||||
FillUserByProviderID(user *model.User, providerUserID string) error
|
||||
|
||||
// SetProviderUserID sets the provider user ID on the user model
|
||||
SetProviderUserID(user *model.User, providerUserID string)
|
||||
|
||||
// GetProviderPrefix returns the prefix for auto-generated usernames (e.g., "github_")
|
||||
GetProviderPrefix() string
|
||||
}
|
||||
134
oauth/registry.go
Normal file
134
oauth/registry.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
)
|
||||
|
||||
var (
|
||||
providers = make(map[string]Provider)
|
||||
mu sync.RWMutex
|
||||
// customProviderSlugs tracks which providers are custom (can be unregistered)
|
||||
customProviderSlugs = make(map[string]bool)
|
||||
)
|
||||
|
||||
// Register registers an OAuth provider with the given name
|
||||
func Register(name string, provider Provider) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
providers[name] = provider
|
||||
}
|
||||
|
||||
// RegisterCustom registers a custom OAuth provider (can be unregistered later)
|
||||
func RegisterCustom(name string, provider Provider) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
providers[name] = provider
|
||||
customProviderSlugs[name] = true
|
||||
}
|
||||
|
||||
// Unregister removes a provider from the registry
|
||||
func Unregister(name string) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
delete(providers, name)
|
||||
delete(customProviderSlugs, name)
|
||||
}
|
||||
|
||||
// GetProvider returns the OAuth provider for the given name
|
||||
func GetProvider(name string) Provider {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return providers[name]
|
||||
}
|
||||
|
||||
// GetAllProviders returns all registered OAuth providers
|
||||
func GetAllProviders() map[string]Provider {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
result := make(map[string]Provider, len(providers))
|
||||
for k, v := range providers {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetEnabledCustomProviders returns all enabled custom OAuth providers
|
||||
func GetEnabledCustomProviders() []*GenericOAuthProvider {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
var result []*GenericOAuthProvider
|
||||
for name, provider := range providers {
|
||||
if customProviderSlugs[name] {
|
||||
if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() {
|
||||
result = append(result, gp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IsProviderRegistered checks if a provider is registered
|
||||
func IsProviderRegistered(name string) bool {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
_, ok := providers[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsCustomProvider checks if a provider is a custom provider
|
||||
func IsCustomProvider(name string) bool {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return customProviderSlugs[name]
|
||||
}
|
||||
|
||||
// LoadCustomProviders loads all custom OAuth providers from the database
|
||||
func LoadCustomProviders() error {
|
||||
// First, unregister all existing custom providers
|
||||
mu.Lock()
|
||||
for name := range customProviderSlugs {
|
||||
delete(providers, name)
|
||||
}
|
||||
customProviderSlugs = make(map[string]bool)
|
||||
mu.Unlock()
|
||||
|
||||
// Load all custom providers from database
|
||||
customProviders, err := model.GetAllCustomOAuthProviders()
|
||||
if err != nil {
|
||||
common.SysError("Failed to load custom OAuth providers: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// Register each custom provider
|
||||
for _, config := range customProviders {
|
||||
provider := NewGenericOAuthProvider(config)
|
||||
RegisterCustom(config.Slug, provider)
|
||||
common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")")
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReloadCustomProviders reloads all custom OAuth providers from the database
|
||||
func ReloadCustomProviders() error {
|
||||
return LoadCustomProviders()
|
||||
}
|
||||
|
||||
// RegisterOrUpdateCustomProvider registers or updates a single custom provider
|
||||
func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) {
|
||||
provider := NewGenericOAuthProvider(config)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
providers[config.Slug] = provider
|
||||
customProviderSlugs[config.Slug] = true
|
||||
}
|
||||
|
||||
// UnregisterCustomProvider unregisters a custom provider by slug
|
||||
func UnregisterCustomProvider(slug string) {
|
||||
Unregister(slug)
|
||||
}
|
||||
59
oauth/types.go
Normal file
59
oauth/types.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package oauth
|
||||
|
||||
// OAuthToken represents the token received from OAuth provider
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthUser represents the user info from OAuth provider
|
||||
type OAuthUser struct {
|
||||
// ProviderUserID is the unique identifier from the OAuth provider
|
||||
ProviderUserID string
|
||||
// Username is the username from the OAuth provider (e.g., GitHub login)
|
||||
Username string
|
||||
// DisplayName is the display name from the OAuth provider
|
||||
DisplayName string
|
||||
// Email is the email from the OAuth provider
|
||||
Email string
|
||||
// Extra contains any additional provider-specific data
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// OAuthError represents a translatable OAuth error
|
||||
type OAuthError struct {
|
||||
// MsgKey is the i18n message key
|
||||
MsgKey string
|
||||
// Params contains optional parameters for the message template
|
||||
Params map[string]any
|
||||
// RawError is the underlying error for logging purposes
|
||||
RawError string
|
||||
}
|
||||
|
||||
func (e *OAuthError) Error() string {
|
||||
if e.RawError != "" {
|
||||
return e.RawError
|
||||
}
|
||||
return e.MsgKey
|
||||
}
|
||||
|
||||
// NewOAuthError creates a new OAuth error with the given message key
|
||||
func NewOAuthError(msgKey string, params map[string]any) *OAuthError {
|
||||
return &OAuthError{
|
||||
MsgKey: msgKey,
|
||||
Params: params,
|
||||
}
|
||||
}
|
||||
|
||||
// NewOAuthErrorWithRaw creates a new OAuth error with raw error message for logging
|
||||
func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) *OAuthError {
|
||||
return &OAuthError{
|
||||
MsgKey: msgKey,
|
||||
Params: params,
|
||||
RawError: rawError,
|
||||
}
|
||||
}
|
||||
@@ -224,10 +224,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case types.RelayFormatClaude:
|
||||
if supportsAliAnthropicMessages(info.UpstreamModelName) {
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
}
|
||||
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
|
||||
adaptor := openai.Adaptor{}
|
||||
|
||||
@@ -58,6 +58,8 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
|
||||
"cookie": {},
|
||||
|
||||
// Additional headers that should not be forwarded by name-matching passthrough rules.
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
|
||||
@@ -3,9 +3,6 @@ package aws
|
||||
import "strings"
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
@@ -19,6 +16,7 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
@@ -82,6 +80,11 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-opus-4-6-v1": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
"eu": true,
|
||||
},
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"us": true,
|
||||
"ap": true,
|
||||
|
||||
@@ -26,6 +26,7 @@ type AwsClaudeRequest struct {
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *dto.Thinking `json:"thinking,omitempty"`
|
||||
OutputConfig json.RawMessage `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {
|
||||
|
||||
@@ -233,7 +233,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
||||
switch v := event.(type) {
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes))
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
@@ -277,7 +277,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
||||
}
|
||||
}
|
||||
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
@@ -16,13 +15,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestModeCompletion = 1
|
||||
RequestModeMessage = 2
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
@@ -45,20 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
|
||||
a.RequestMode = RequestModeCompletion
|
||||
} else {
|
||||
a.RequestMode = RequestModeMessage
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
baseURL := ""
|
||||
if a.RequestMode == RequestModeMessage {
|
||||
baseURL = fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
} else {
|
||||
baseURL = fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl)
|
||||
}
|
||||
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
|
||||
if info.IsClaudeBetaQuery {
|
||||
baseURL = baseURL + "?beta=true"
|
||||
}
|
||||
@@ -90,11 +73,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeCompletion {
|
||||
return RequestOpenAI2ClaudeComplete(*request), nil
|
||||
} else {
|
||||
return RequestOpenAI2ClaudeMessage(c, *request)
|
||||
}
|
||||
return RequestOpenAI2ClaudeMessage(c, *request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -117,11 +96,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
return ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
return ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return ClaudeHandler(c, resp, info, a.RequestMode)
|
||||
return ClaudeHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package claude
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-instant-1.2",
|
||||
"claude-2",
|
||||
"claude-2.0",
|
||||
"claude-2.1",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
@@ -24,6 +20,11 @@ var ModelList = []string{
|
||||
"claude-sonnet-4-5-20250929-thinking",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5-20251101-thinking",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-6-max",
|
||||
"claude-opus-4-6-high",
|
||||
"claude-opus-4-6-medium",
|
||||
"claude-opus-4-6-low",
|
||||
}
|
||||
|
||||
var ChannelName = "claude"
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/reasonmap"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -41,37 +42,6 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
|
||||
|
||||
claudeRequest := dto.ClaudeRequest{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
StopSequences: nil,
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
}
|
||||
if claudeRequest.MaxTokensToSample == 0 {
|
||||
claudeRequest.MaxTokensToSample = 4096
|
||||
}
|
||||
prompt := ""
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "user" {
|
||||
prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
|
||||
} else if message.Role == "system" {
|
||||
if prompt == "" {
|
||||
prompt = message.StringContent()
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
claudeRequest.Prompt = prompt
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
|
||||
claudeTools := make([]any, 0, len(textRequest.Tools))
|
||||
|
||||
@@ -172,7 +142,16 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
|
||||
strings.HasPrefix(textRequest.Model, "claude-opus-4-6") {
|
||||
claudeRequest.Model = baseModel
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "adaptive",
|
||||
}
|
||||
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
|
||||
claudeRequest.TopP = 0
|
||||
claudeRequest.Temperature = common.GetPointer[float64](1.0)
|
||||
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(textRequest.Model, "-thinking") {
|
||||
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
@@ -411,7 +390,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
return &claudeRequest, nil
|
||||
}
|
||||
|
||||
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
||||
func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
@@ -425,74 +404,66 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
}
|
||||
}
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice.Delta.SetContentString(claudeResponse.Completion)
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
if claudeResponse.Type == "message_start" {
|
||||
if claudeResponse.Message != nil {
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
}
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
if claudeResponse.Message != nil {
|
||||
response.Id = claudeResponse.Message.Id
|
||||
response.Model = claudeResponse.Message.Model
|
||||
//claudeUsage = &claudeResponse.Message.Usage
|
||||
choice.Delta.SetContentString("")
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
// 如果是文本块,尽可能发送首段文本(若存在)
|
||||
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
|
||||
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
|
||||
}
|
||||
//claudeUsage = &claudeResponse.Message.Usage
|
||||
choice.Delta.SetContentString("")
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
// 如果是文本块,尽可能发送首段文本(若存在)
|
||||
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
|
||||
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
|
||||
}
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Index: common.GetPointer(fcIdx),
|
||||
ID: claudeResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Function: dto.FunctionResponse{
|
||||
Name: claudeResponse.ContentBlock.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Index: common.GetPointer(fcIdx),
|
||||
ID: claudeResponse.ContentBlock.Id,
|
||||
Type: "function",
|
||||
Function: dto.FunctionResponse{
|
||||
Name: claudeResponse.ContentBlock.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta != nil {
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
switch claudeResponse.Delta.Type {
|
||||
case "input_json_delta":
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Type: "function",
|
||||
Index: common.GetPointer(fcIdx),
|
||||
Function: dto.FunctionResponse{
|
||||
Arguments: *claudeResponse.Delta.PartialJson,
|
||||
},
|
||||
})
|
||||
case "signature_delta":
|
||||
// 加密的不处理
|
||||
signatureContent := "\n"
|
||||
choice.Delta.ReasoningContent = &signatureContent
|
||||
case "thinking_delta":
|
||||
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
//claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
return nil
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta != nil {
|
||||
choice.Delta.Content = claudeResponse.Delta.Text
|
||||
switch claudeResponse.Delta.Type {
|
||||
case "input_json_delta":
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Type: "function",
|
||||
Index: common.GetPointer(fcIdx),
|
||||
Function: dto.FunctionResponse{
|
||||
Arguments: *claudeResponse.Delta.PartialJson,
|
||||
},
|
||||
})
|
||||
case "signature_delta":
|
||||
// 加密的不处理
|
||||
signatureContent := "\n"
|
||||
choice.Delta.ReasoningContent = &signatureContent
|
||||
case "thinking_delta":
|
||||
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
//claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
return nil
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
if len(tools) > 0 {
|
||||
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
|
||||
@@ -503,7 +474,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
return &response
|
||||
}
|
||||
|
||||
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
|
||||
func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
|
||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
@@ -521,39 +492,26 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
|
||||
tools := make([]dto.ToolCallResponse, 0)
|
||||
thinkingContent := ""
|
||||
|
||||
if reqMode == RequestModeCompletion {
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
} else {
|
||||
fullTextResponse.Id = claudeResponse.Id
|
||||
for _, message := range claudeResponse.Content {
|
||||
switch message.Type {
|
||||
case "tool_use":
|
||||
args, _ := json.Marshal(message.Input)
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
ID: message.Id,
|
||||
Type: "function", // compatible with other OpenAI derivative applications
|
||||
Function: dto.FunctionResponse{
|
||||
Name: message.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
case "thinking":
|
||||
// 加密的不管, 只输出明文的推理过程
|
||||
if message.Thinking != nil {
|
||||
thinkingContent = *message.Thinking
|
||||
}
|
||||
case "text":
|
||||
responseText = message.GetText()
|
||||
fullTextResponse.Id = claudeResponse.Id
|
||||
for _, message := range claudeResponse.Content {
|
||||
switch message.Type {
|
||||
case "tool_use":
|
||||
args, _ := json.Marshal(message.Input)
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
ID: message.Id,
|
||||
Type: "function", // compatible with other OpenAI derivative applications
|
||||
Function: dto.FunctionResponse{
|
||||
Name: message.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
case "thinking":
|
||||
// 加密的不管, 只输出明文的推理过程
|
||||
if message.Thinking != nil {
|
||||
thinkingContent = *message.Thinking
|
||||
}
|
||||
case "text":
|
||||
responseText = message.GetText()
|
||||
}
|
||||
}
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
@@ -586,71 +544,67 @@ type ClaudeResponseInfo struct {
|
||||
Done bool
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
if claudeInfo == nil {
|
||||
return false
|
||||
}
|
||||
if claudeInfo.Usage == nil {
|
||||
claudeInfo.Usage = &dto.Usage{}
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
if claudeResponse.Message != nil {
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
}
|
||||
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
|
||||
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta != nil {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage != nil {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
if claudeResponse.Usage.CacheReadInputTokens > 0 {
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
}
|
||||
if claudeResponse.Usage.CacheCreationInputTokens > 0 {
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if cacheCreation5m := claudeResponse.Usage.GetCacheCreation5mTokens(); cacheCreation5m > 0 {
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = cacheCreation5m
|
||||
}
|
||||
if cacheCreation1h := claudeResponse.Usage.GetCacheCreation1hTokens(); cacheCreation1h > 0 {
|
||||
claudeInfo.Usage.ClaudeCacheCreation1hTokens = cacheCreation1h
|
||||
}
|
||||
if claudeResponse.Usage.OutputTokens > 0 {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
if claudeResponse.Type == "message_start" {
|
||||
if claudeResponse.Message != nil {
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
}
|
||||
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
|
||||
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta != nil {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage != nil {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
if claudeResponse.Usage.CacheReadInputTokens > 0 {
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
}
|
||||
if claudeResponse.Usage.CacheCreationInputTokens > 0 {
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
}
|
||||
if cacheCreation5m := claudeResponse.Usage.GetCacheCreation5mTokens(); cacheCreation5m > 0 {
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = cacheCreation5m
|
||||
}
|
||||
if cacheCreation1h := claudeResponse.Usage.GetCacheCreation1hTokens(); cacheCreation1h > 0 {
|
||||
claudeInfo.Usage.ClaudeCacheCreation1hTokens = cacheCreation1h
|
||||
}
|
||||
if claudeResponse.Usage.OutputTokens > 0 {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
if oaiResponse != nil {
|
||||
oaiResponse.Id = claudeInfo.ResponseId
|
||||
@@ -660,7 +614,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
return true
|
||||
}
|
||||
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
@@ -677,24 +631,19 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
maybeMarkClaudeRefusal(c, *claudeResponse.Delta.StopReason)
|
||||
}
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
FormatClaudeResponseInfo(&claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil {
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
}
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
if claudeResponse.Message != nil {
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
} else if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
response := StreamResponseClaude2OpenAI(&claudeResponse)
|
||||
|
||||
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
|
||||
if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -706,20 +655,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo) {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
@@ -736,7 +680,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
@@ -746,7 +690,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -756,11 +700,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
|
||||
HandleStreamFinalResponse(c, info, claudeInfo)
|
||||
return claudeInfo.Usage, nil
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
@@ -770,26 +714,22 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
|
||||
}
|
||||
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
if claudeInfo.Usage == nil {
|
||||
claudeInfo.Usage = &dto.Usage{}
|
||||
}
|
||||
if claudeResponse.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
|
||||
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
|
||||
}
|
||||
if claudeInfo.Usage == nil {
|
||||
claudeInfo.Usage = &dto.Usage{}
|
||||
}
|
||||
if claudeResponse.Usage != nil {
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
|
||||
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
|
||||
}
|
||||
var responseData []byte
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
@@ -807,7 +747,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
@@ -824,7 +764,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody)
|
||||
if handleErr != nil {
|
||||
return nil, handleErr
|
||||
}
|
||||
|
||||
@@ -90,6 +90,12 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
}
|
||||
}
|
||||
// Codex backend requires the `instructions` field to be present.
|
||||
// Keep it consistent with Codex CLI behavior by defaulting to an empty string.
|
||||
if len(request.Instructions) == 0 {
|
||||
request.Instructions = json.RawMessage(`""`)
|
||||
}
|
||||
|
||||
if isCompact {
|
||||
return request, nil
|
||||
}
|
||||
@@ -172,5 +178,15 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
req.Set("originator", "codex_cli_rs")
|
||||
}
|
||||
|
||||
// chatgpt.com/backend-api/codex/responses is strict about Content-Type.
|
||||
// Clients may omit it or include parameters like `application/json; charset=utf-8`,
|
||||
// which can be rejected by the upstream. Force the exact media type.
|
||||
req.Set("Content-Type", "application/json")
|
||||
if info.IsStream {
|
||||
req.Set("Accept", "text/event-stream")
|
||||
} else if req.Get("Accept") == "" {
|
||||
req.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,9 +96,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
|
||||
@@ -1258,8 +1258,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
|
||||
if usage.CompletionTokens <= 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
if info.ReceivedResponseCount > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
} else {
|
||||
usage = &dto.Usage{}
|
||||
|
||||
@@ -103,9 +103,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
|
||||
@@ -42,6 +42,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
@@ -367,7 +368,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||
@@ -380,7 +381,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
case RequestModeGemini:
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
return gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
@@ -17,6 +19,7 @@ type VertexAIClaudeRequest struct {
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
Thinking *dto.Thinking `json:"thinking,omitempty"`
|
||||
OutputConfig json.RawMessage `json:"output_config,omitempty"`
|
||||
}
|
||||
|
||||
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
|
||||
@@ -33,5 +36,6 @@ func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
Thinking: req.Thinking,
|
||||
OutputConfig: req.OutputConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,9 +348,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
if _, ok := channelconstant.ChannelSpecialBases[info.ChannelBaseUrl]; ok {
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
}
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -110,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
return claude.ClaudeHandler(c, resp, info)
|
||||
}
|
||||
default:
|
||||
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
|
||||
@@ -2,6 +2,7 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -49,7 +51,17 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
|
||||
}
|
||||
|
||||
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" &&
|
||||
strings.HasPrefix(request.Model, "claude-opus-4-6") {
|
||||
request.Model = baseModel
|
||||
request.Thinking = &dto.Thinking{
|
||||
Type: "adaptive",
|
||||
}
|
||||
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
|
||||
request.TopP = 0
|
||||
request.Temperature = common.GetPointer[float64](1.0)
|
||||
info.UpstreamModelName = request.Model
|
||||
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
|
||||
strings.HasSuffix(request.Model, "-thinking") {
|
||||
if request.Thinking == nil {
|
||||
// 因为BudgetTokens 必须大于1024
|
||||
|
||||
@@ -113,6 +113,7 @@ type RelayInfo struct {
|
||||
UserQuota int
|
||||
RelayFormat types.RelayFormat
|
||||
SendResponseCount int
|
||||
ReceivedResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
// BillingSource indicates whether this request is billed from wallet quota or subscription.
|
||||
// "" or "wallet" => wallet; "subscription" => subscription
|
||||
|
||||
@@ -90,10 +90,10 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
|
||||
// 等待所有 goroutine 退出,最多等待5秒
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
gopool.Go(func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
@@ -138,11 +138,11 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
case <-pingTicker.C:
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
gopool.Go(func() {
|
||||
writeMutex.Lock()
|
||||
defer writeMutex.Unlock()
|
||||
done <- PingData(c)
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
@@ -219,14 +219,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
info.SetFirstResponseTime()
|
||||
|
||||
info.ReceivedResponseCount++
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan bool, 1)
|
||||
go func() {
|
||||
gopool.Go(func() {
|
||||
writeMutex.Lock()
|
||||
defer writeMutex.Unlock()
|
||||
done <- dataHandler(data)
|
||||
}()
|
||||
})
|
||||
|
||||
select {
|
||||
case success := <-done:
|
||||
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"github.com/QuantumNous/new-api/controller"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
|
||||
// Import oauth package to register providers via init()
|
||||
_ "github.com/QuantumNous/new-api/oauth"
|
||||
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -30,16 +33,16 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
// OAuth routes - specific routes must come before :provider wildcard
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
// Non-standard OAuth (WeChat, Telegram) - keep original routes
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
|
||||
apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth)
|
||||
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
@@ -99,6 +102,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
// Check-in routes
|
||||
selfRoute.GET("/checkin", controller.GetCheckinStatus)
|
||||
selfRoute.POST("/checkin", middleware.TurnstileCheck(), controller.DoCheckin)
|
||||
|
||||
// Custom OAuth bindings
|
||||
selfRoute.GET("/oauth/bindings", controller.GetUserOAuthBindings)
|
||||
selfRoute.DELETE("/oauth/bindings/:provider_id", controller.UnbindCustomOAuth)
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
@@ -163,6 +170,17 @@ func SetApiRouter(router *gin.Engine) {
|
||||
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||
}
|
||||
|
||||
// Custom OAuth provider management (admin only)
|
||||
customOAuthRoute := apiRouter.Group("/custom-oauth-provider")
|
||||
customOAuthRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
customOAuthRoute.GET("/", controller.GetCustomOAuthProviders)
|
||||
customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider)
|
||||
customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider)
|
||||
customOAuthRoute.PUT("/:id", controller.UpdateCustomOAuthProvider)
|
||||
customOAuthRoute.DELETE("/:id", controller.DeleteCustomOAuthProvider)
|
||||
}
|
||||
performanceRoute := apiRouter.Group("/performance")
|
||||
performanceRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
@@ -220,7 +238,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
tokenRoute.Use(middleware.UserAuth())
|
||||
{
|
||||
tokenRoute.GET("/", controller.GetAllTokens)
|
||||
tokenRoute.GET("/search", controller.SearchTokens)
|
||||
tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
|
||||
tokenRoute.GET("/:id", controller.GetToken)
|
||||
tokenRoute.POST("/", controller.AddToken)
|
||||
tokenRoute.PUT("/", controller.UpdateToken)
|
||||
|
||||
@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(strValue, 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
// 兼容 float 格式的字符串(如 "2.000000")
|
||||
floatValue, fErr := strconv.ParseFloat(strValue, 64)
|
||||
if fErr != nil {
|
||||
continue
|
||||
}
|
||||
intValue = int64(floatValue)
|
||||
}
|
||||
field.SetInt(intValue)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
uintValue, err := strconv.ParseUint(strValue, 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
// 兼容 float 格式的字符串
|
||||
floatValue, fErr := strconv.ParseFloat(strValue, 64)
|
||||
if fErr != nil || floatValue < 0 {
|
||||
continue
|
||||
}
|
||||
uintValue = uint64(floatValue)
|
||||
}
|
||||
field.SetUint(uintValue)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
|
||||
28
setting/operation_setting/token_setting.go
Normal file
28
setting/operation_setting/token_setting.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package operation_setting
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
// TokenSetting 令牌相关配置
|
||||
type TokenSetting struct {
|
||||
MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var tokenSetting = TokenSetting{
|
||||
MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("token_setting", &tokenSetting)
|
||||
}
|
||||
|
||||
// GetTokenSetting 获取令牌配置
|
||||
func GetTokenSetting() *TokenSetting {
|
||||
return &tokenSetting
|
||||
}
|
||||
|
||||
// GetMaxUserTokens 获取每用户最大令牌数量
|
||||
func GetMaxUserTokens() int {
|
||||
return GetTokenSetting().MaxUserTokens
|
||||
}
|
||||
@@ -60,6 +60,12 @@ var defaultCacheRatio = map[string]float64{
|
||||
"claude-sonnet-4-5-20250929-thinking": 0.1,
|
||||
"claude-opus-4-5-20251101": 0.1,
|
||||
"claude-opus-4-5-20251101-thinking": 0.1,
|
||||
"claude-opus-4-6": 0.1,
|
||||
"claude-opus-4-6-thinking": 0.1,
|
||||
"claude-opus-4-6-max": 0.1,
|
||||
"claude-opus-4-6-high": 0.1,
|
||||
"claude-opus-4-6-medium": 0.1,
|
||||
"claude-opus-4-6-low": 0.1,
|
||||
}
|
||||
|
||||
var defaultCreateCacheRatio = map[string]float64{
|
||||
@@ -82,6 +88,12 @@ var defaultCreateCacheRatio = map[string]float64{
|
||||
"claude-sonnet-4-5-20250929-thinking": 1.25,
|
||||
"claude-opus-4-5-20251101": 1.25,
|
||||
"claude-opus-4-5-20251101-thinking": 1.25,
|
||||
"claude-opus-4-6": 1.25,
|
||||
"claude-opus-4-6-thinking": 1.25,
|
||||
"claude-opus-4-6-max": 1.25,
|
||||
"claude-opus-4-6-high": 1.25,
|
||||
"claude-opus-4-6-medium": 1.25,
|
||||
"claude-opus-4-6-low": 1.25,
|
||||
}
|
||||
|
||||
//var defaultCreateCacheRatio = map[string]float64{}
|
||||
|
||||
@@ -131,9 +131,6 @@ var defaultModelRatio = map[string]float64{
|
||||
"text-search-ada-doc-001": 10,
|
||||
"text-moderation-stable": 0.1,
|
||||
"text-moderation-latest": 0.1,
|
||||
"claude-instant-1": 0.4, // $0.8 / 1M tokens
|
||||
"claude-2.0": 4, // $8 / 1M tokens
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
|
||||
"claude-haiku-4-5-20251001": 0.5, // $1 / 1M tokens
|
||||
@@ -145,6 +142,11 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-sonnet-4-20250514": 1.5,
|
||||
"claude-sonnet-4-5-20250929": 1.5,
|
||||
"claude-opus-4-5-20251101": 2.5,
|
||||
"claude-opus-4-6": 2.5,
|
||||
"claude-opus-4-6-max": 2.5,
|
||||
"claude-opus-4-6-high": 2.5,
|
||||
"claude-opus-4-6-medium": 2.5,
|
||||
"claude-opus-4-6-low": 2.5,
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"claude-opus-4-20250514": 7.5,
|
||||
"claude-opus-4-1-20250805": 7.5,
|
||||
@@ -589,8 +591,6 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
return 5, true
|
||||
} else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") || strings.Contains(name, "claude-haiku-4") {
|
||||
return 5, true
|
||||
} else if strings.Contains(name, "claude-instant-1") || strings.Contains(name, "claude-2") {
|
||||
return 3, true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "gpt-3.5") {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var EffortSuffixes = []string{"-high", "-medium", "-low", "-minimal"}
|
||||
var EffortSuffixes = []string{"-max", "-high", "-medium", "-low", "-minimal"}
|
||||
|
||||
// TrimEffortSuffix -> modelName level(low) exists
|
||||
func TrimEffortSuffix(modelName string) (string, string, bool) {
|
||||
|
||||
@@ -34,6 +34,7 @@ import {
|
||||
onDiscordOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
onCustomOAuthClicked,
|
||||
prepareCredentialRequestOptions,
|
||||
buildAssertionResult,
|
||||
isPasskeySupported,
|
||||
@@ -109,6 +110,7 @@ const LoginForm = () => {
|
||||
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
|
||||
const githubTimeoutRef = useRef(null);
|
||||
const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]);
|
||||
const [customOAuthLoading, setCustomOAuthLoading] = useState({});
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -373,6 +375,23 @@ const LoginForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的自定义OAuth登录点击处理
|
||||
const handleCustomOAuthClick = (provider) => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
showInfo(t('请先阅读并同意用户协议和隐私政策'));
|
||||
return;
|
||||
}
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true }));
|
||||
try {
|
||||
onCustomOAuthClicked(provider, { shouldLogout: true });
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => {
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false }));
|
||||
}, 3000);
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的邮箱登录选项点击处理
|
||||
const handleEmailLoginClick = () => {
|
||||
setEmailLoginLoading(true);
|
||||
@@ -572,6 +591,23 @@ const LoginForm = () => {
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.custom_oauth_providers &&
|
||||
status.custom_oauth_providers.map((provider) => (
|
||||
<Button
|
||||
key={provider.slug}
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<IconLock size='large' />}
|
||||
onClick={() => handleCustomOAuthClick(provider)}
|
||||
loading={customOAuthLoading[provider.slug]}
|
||||
>
|
||||
<span className='ml-3'>
|
||||
{t('使用 {{name}} 继续', { name: provider.name })}
|
||||
</span>
|
||||
</Button>
|
||||
))}
|
||||
|
||||
{status.telegram_oauth && (
|
||||
<div className='flex justify-center my-2'>
|
||||
<TelegramLoginButton
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useContext, useEffect } from 'react';
|
||||
import React, { useContext, useEffect, useRef } from 'react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
@@ -35,6 +35,9 @@ const OAuth2Callback = (props) => {
|
||||
const [searchParams] = useSearchParams();
|
||||
const [, userDispatch] = useContext(UserContext);
|
||||
const navigate = useNavigate();
|
||||
|
||||
// 防止 React 18 Strict Mode 下重复执行
|
||||
const hasExecuted = useRef(false);
|
||||
|
||||
// 最大重试次数
|
||||
const MAX_RETRIES = 3;
|
||||
@@ -48,7 +51,9 @@ const OAuth2Callback = (props) => {
|
||||
const { success, message, data } = resData;
|
||||
|
||||
if (!success) {
|
||||
throw new Error(message || 'OAuth2 callback error');
|
||||
// 业务错误不重试,直接显示错误
|
||||
showError(message || t('授权失败'));
|
||||
return;
|
||||
}
|
||||
|
||||
if (message === 'bind') {
|
||||
@@ -63,6 +68,7 @@ const OAuth2Callback = (props) => {
|
||||
navigate('/console/token');
|
||||
}
|
||||
} catch (error) {
|
||||
// 网络错误等可重试
|
||||
if (retry < MAX_RETRIES) {
|
||||
// 递增的退避等待
|
||||
await new Promise((resolve) => setTimeout(resolve, (retry + 1) * 2000));
|
||||
@@ -76,6 +82,12 @@ const OAuth2Callback = (props) => {
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
// 防止 React 18 Strict Mode 下重复执行
|
||||
if (hasExecuted.current) {
|
||||
return;
|
||||
}
|
||||
hasExecuted.current = true;
|
||||
|
||||
const code = searchParams.get('code');
|
||||
const state = searchParams.get('state');
|
||||
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Button, Modal } from '@douyinfe/semi-ui';
|
||||
import { useSecureVerification } from '../../../hooks/common/useSecureVerification';
|
||||
import { createApiCalls } from '../../../services/secureVerification';
|
||||
import SecureVerificationModal from '../modals/SecureVerificationModal';
|
||||
import ChannelKeyDisplay from '../ui/ChannelKeyDisplay';
|
||||
|
||||
/**
|
||||
* 渠道密钥查看组件使用示例
|
||||
* 展示如何使用通用安全验证系统
|
||||
*/
|
||||
const ChannelKeyViewExample = ({ channelId }) => {
|
||||
const { t } = useTranslation();
|
||||
const [keyData, setKeyData] = useState('');
|
||||
const [showKeyModal, setShowKeyModal] = useState(false);
|
||||
|
||||
// 使用通用安全验证 Hook
|
||||
const {
|
||||
isModalVisible,
|
||||
verificationMethods,
|
||||
verificationState,
|
||||
startVerification,
|
||||
executeVerification,
|
||||
cancelVerification,
|
||||
setVerificationCode,
|
||||
switchVerificationMethod,
|
||||
} = useSecureVerification({
|
||||
onSuccess: (result) => {
|
||||
// 验证成功后处理结果
|
||||
if (result.success && result.data?.key) {
|
||||
setKeyData(result.data.key);
|
||||
setShowKeyModal(true);
|
||||
}
|
||||
},
|
||||
successMessage: t('密钥获取成功'),
|
||||
});
|
||||
|
||||
// 开始查看密钥流程
|
||||
const handleViewKey = async () => {
|
||||
const apiCall = createApiCalls.viewChannelKey(channelId);
|
||||
|
||||
await startVerification(apiCall, {
|
||||
title: t('查看渠道密钥'),
|
||||
description: t('为了保护账户安全,请验证您的身份。'),
|
||||
preferredMethod: 'passkey', // 可以指定首选验证方式
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* 查看密钥按钮 */}
|
||||
<Button type='primary' theme='outline' onClick={handleViewKey}>
|
||||
{t('查看密钥')}
|
||||
</Button>
|
||||
|
||||
{/* 安全验证模态框 */}
|
||||
<SecureVerificationModal
|
||||
visible={isModalVisible}
|
||||
verificationMethods={verificationMethods}
|
||||
verificationState={verificationState}
|
||||
onVerify={executeVerification}
|
||||
onCancel={cancelVerification}
|
||||
onCodeChange={setVerificationCode}
|
||||
onMethodSwitch={switchVerificationMethod}
|
||||
title={verificationState.title}
|
||||
description={verificationState.description}
|
||||
/>
|
||||
|
||||
{/* 密钥显示模态框 */}
|
||||
<Modal
|
||||
title={t('渠道密钥信息')}
|
||||
visible={showKeyModal}
|
||||
onCancel={() => setShowKeyModal(false)}
|
||||
footer={
|
||||
<Button type='primary' onClick={() => setShowKeyModal(false)}>
|
||||
{t('完成')}
|
||||
</Button>
|
||||
}
|
||||
width={700}
|
||||
style={{ maxWidth: '90vw' }}
|
||||
>
|
||||
<ChannelKeyDisplay
|
||||
keyData={keyData}
|
||||
showSuccessIcon={true}
|
||||
successText={t('密钥获取成功')}
|
||||
showWarning={true}
|
||||
/>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ChannelKeyViewExample;
|
||||
@@ -93,6 +93,49 @@ export function Mermaid(props) {
|
||||
);
|
||||
}
|
||||
|
||||
function SandboxedHtmlPreview({ code }) {
|
||||
const iframeRef = useRef(null);
|
||||
const [iframeHeight, setIframeHeight] = useState(150);
|
||||
|
||||
useEffect(() => {
|
||||
const iframe = iframeRef.current;
|
||||
if (!iframe) return;
|
||||
|
||||
const handleLoad = () => {
|
||||
try {
|
||||
const doc = iframe.contentDocument || iframe.contentWindow?.document;
|
||||
if (doc) {
|
||||
const height =
|
||||
doc.documentElement.scrollHeight || doc.body.scrollHeight;
|
||||
setIframeHeight(Math.min(Math.max(height + 16, 60), 600));
|
||||
}
|
||||
} catch {
|
||||
// sandbox restrictions may prevent access, that's fine
|
||||
}
|
||||
};
|
||||
|
||||
iframe.addEventListener('load', handleLoad);
|
||||
return () => iframe.removeEventListener('load', handleLoad);
|
||||
}, [code]);
|
||||
|
||||
return (
|
||||
<iframe
|
||||
ref={iframeRef}
|
||||
sandbox='allow-same-origin'
|
||||
srcDoc={code}
|
||||
title='HTML Preview'
|
||||
style={{
|
||||
width: '100%',
|
||||
height: `${iframeHeight}px`,
|
||||
border: 'none',
|
||||
overflow: 'auto',
|
||||
backgroundColor: '#fff',
|
||||
borderRadius: '4px',
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export function PreCode(props) {
|
||||
const ref = useRef(null);
|
||||
const [mermaidCode, setMermaidCode] = useState('');
|
||||
@@ -227,7 +270,7 @@ export function PreCode(props) {
|
||||
>
|
||||
HTML预览:
|
||||
</div>
|
||||
<div dangerouslySetInnerHTML={{ __html: htmlCode }} />
|
||||
<SandboxedHtmlPreview code={htmlCode} />
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Modal, Button, Input, Typography } from '@douyinfe/semi-ui';
|
||||
|
||||
/**
|
||||
* 可复用的两步验证模态框组件
|
||||
* @param {Object} props
|
||||
* @param {boolean} props.visible - 是否显示模态框
|
||||
* @param {string} props.code - 验证码值
|
||||
* @param {boolean} props.loading - 是否正在验证
|
||||
* @param {Function} props.onCodeChange - 验证码变化回调
|
||||
* @param {Function} props.onVerify - 验证回调
|
||||
* @param {Function} props.onCancel - 取消回调
|
||||
* @param {string} props.title - 模态框标题
|
||||
* @param {string} props.description - 验证描述文本
|
||||
* @param {string} props.placeholder - 输入框占位文本
|
||||
*/
|
||||
const TwoFactorAuthModal = ({
|
||||
visible,
|
||||
code,
|
||||
loading,
|
||||
onCodeChange,
|
||||
onVerify,
|
||||
onCancel,
|
||||
title,
|
||||
description,
|
||||
placeholder,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleKeyDown = (e) => {
|
||||
if (e.key === 'Enter' && code && !loading) {
|
||||
onVerify();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={
|
||||
<div className='flex items-center'>
|
||||
<div className='w-8 h-8 rounded-full bg-blue-100 dark:bg-blue-900 flex items-center justify-center mr-3'>
|
||||
<svg
|
||||
className='w-4 h-4 text-blue-600 dark:text-blue-400'
|
||||
fill='currentColor'
|
||||
viewBox='0 0 20 20'
|
||||
>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
d='M5 9V7a5 5 0 0110 0v2a2 2 0 012 2v5a2 2 0 01-2 2H5a2 2 0 01-2-2v-5a2 2 0 012-2zm8-2v2H7V7a3 3 0 016 0z'
|
||||
clipRule='evenodd'
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
{title || t('安全验证')}
|
||||
</div>
|
||||
}
|
||||
visible={visible}
|
||||
onCancel={onCancel}
|
||||
footer={
|
||||
<>
|
||||
<Button onClick={onCancel}>{t('取消')}</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
loading={loading}
|
||||
disabled={!code || loading}
|
||||
onClick={onVerify}
|
||||
>
|
||||
{t('验证')}
|
||||
</Button>
|
||||
</>
|
||||
}
|
||||
width={500}
|
||||
style={{ maxWidth: '90vw' }}
|
||||
>
|
||||
<div className='space-y-6'>
|
||||
{/* 安全提示 */}
|
||||
<div className='bg-blue-50 dark:bg-blue-900 rounded-lg p-4'>
|
||||
<div className='flex items-start'>
|
||||
<svg
|
||||
className='w-5 h-5 text-blue-600 dark:text-blue-400 mt-0.5 mr-3 flex-shrink-0'
|
||||
fill='currentColor'
|
||||
viewBox='0 0 20 20'
|
||||
>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
d='M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a1 1 0 000 2v3a1 1 0 001 1h1a1 1 0 100-2v-3a1 1 0 00-1-1H9z'
|
||||
clipRule='evenodd'
|
||||
/>
|
||||
</svg>
|
||||
<div>
|
||||
<Typography.Text
|
||||
strong
|
||||
className='text-blue-800 dark:text-blue-200'
|
||||
>
|
||||
{t('安全验证')}
|
||||
</Typography.Text>
|
||||
<Typography.Text className='block text-blue-700 dark:text-blue-300 text-sm mt-1'>
|
||||
{description || t('为了保护账户安全,请验证您的两步验证码。')}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 验证码输入 */}
|
||||
<div>
|
||||
<Typography.Text strong className='block mb-2'>
|
||||
{t('验证身份')}
|
||||
</Typography.Text>
|
||||
<Input
|
||||
placeholder={placeholder || t('请输入认证器验证码或备用码')}
|
||||
value={code}
|
||||
onChange={onCodeChange}
|
||||
size='large'
|
||||
maxLength={8}
|
||||
onKeyDown={handleKeyDown}
|
||||
autoFocus
|
||||
/>
|
||||
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
|
||||
{t(
|
||||
'支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。',
|
||||
)}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default TwoFactorAuthModal;
|
||||
@@ -91,22 +91,45 @@ const codeThemeStyles = {
|
||||
},
|
||||
};
|
||||
|
||||
const highlightJson = (str) => {
|
||||
return str.replace(
|
||||
/("(\\u[a-zA-Z0-9]{4}|\\[^u]|[^\\"])*"(\s*:)?|\b(true|false|null)\b|-?\d+(?:\.\d*)?(?:[eE][+-]?\d+)?)/g,
|
||||
(match) => {
|
||||
let color = '#b5cea8';
|
||||
if (/^"/.test(match)) {
|
||||
color = /:$/.test(match) ? '#9cdcfe' : '#ce9178';
|
||||
} else if (/true|false|null/.test(match)) {
|
||||
color = '#569cd6';
|
||||
}
|
||||
return `<span style="color: ${color}">${match}</span>`;
|
||||
},
|
||||
);
|
||||
const escapeHtml = (str) => {
|
||||
return str
|
||||
.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, ''');
|
||||
};
|
||||
|
||||
const linkRegex = /(https?:\/\/[^\s<"'\]),;}]+)/g;
|
||||
const highlightJson = (str) => {
|
||||
const tokenRegex =
|
||||
/("(\\u[a-zA-Z0-9]{4}|\\[^u]|[^\\"])*"(\s*:)?|\b(true|false|null)\b|-?\d+(?:\.\d*)?(?:[eE][+-]?\d+)?)/g;
|
||||
|
||||
let result = '';
|
||||
let lastIndex = 0;
|
||||
let match;
|
||||
|
||||
while ((match = tokenRegex.exec(str)) !== null) {
|
||||
// Escape non-token text (structural chars like {, }, [, ], :, comma, whitespace)
|
||||
result += escapeHtml(str.slice(lastIndex, match.index));
|
||||
|
||||
const token = match[0];
|
||||
let color = '#b5cea8';
|
||||
if (/^"/.test(token)) {
|
||||
color = /:$/.test(token) ? '#9cdcfe' : '#ce9178';
|
||||
} else if (/true|false|null/.test(token)) {
|
||||
color = '#569cd6';
|
||||
}
|
||||
// Escape token content before wrapping in span
|
||||
result += `<span style="color: ${color}">${escapeHtml(token)}</span>`;
|
||||
lastIndex = tokenRegex.lastIndex;
|
||||
}
|
||||
|
||||
// Escape remaining text
|
||||
result += escapeHtml(str.slice(lastIndex));
|
||||
return result;
|
||||
};
|
||||
|
||||
const linkRegex = /(https?:\/\/(?:[^\s<"'\]),;&}]|&)+)/g;
|
||||
|
||||
const linkifyHtml = (html) => {
|
||||
const parts = html.split(/(<[^>]+>)/g);
|
||||
@@ -184,14 +207,14 @@ const CodeViewer = ({ content, title, language = 'json' }) => {
|
||||
|
||||
const highlightedContent = useMemo(() => {
|
||||
if (contentMetrics.isVeryLarge && !isExpanded) {
|
||||
return displayContent;
|
||||
return escapeHtml(displayContent);
|
||||
}
|
||||
|
||||
if (isJsonLike(displayContent, language)) {
|
||||
return highlightJson(displayContent);
|
||||
}
|
||||
|
||||
return displayContent;
|
||||
return escapeHtml(displayContent);
|
||||
}, [displayContent, language, contentMetrics.isVeryLarge, isExpanded]);
|
||||
|
||||
const renderedContent = useMemo(() => {
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
export { default as SettingsPanel } from './SettingsPanel';
|
||||
export { default as ChatArea } from './ChatArea';
|
||||
export { default as DebugPanel } from './DebugPanel';
|
||||
export { default as MessageContent } from './MessageContent';
|
||||
export { default as MessageActions } from './MessageActions';
|
||||
export { default as CustomInputRender } from './CustomInputRender';
|
||||
export { default as SSEViewer } from './SSEViewer';
|
||||
export { default as ParameterControl } from './ParameterControl';
|
||||
export { default as ImageUrlInput } from './ImageUrlInput';
|
||||
export { default as FloatingButtons } from './FloatingButtons';
|
||||
export { default as ConfigManager } from './ConfigManager';
|
||||
|
||||
export {
|
||||
saveConfig,
|
||||
loadConfig,
|
||||
clearConfig,
|
||||
hasStoredConfig,
|
||||
getConfigTimestamp,
|
||||
exportConfig,
|
||||
importConfig,
|
||||
} from './configStorage';
|
||||
631
web/src/components/settings/CustomOAuthSetting.jsx
Normal file
631
web/src/components/settings/CustomOAuthSetting.jsx
Normal file
@@ -0,0 +1,631 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
Button,
|
||||
Form,
|
||||
Row,
|
||||
Col,
|
||||
Typography,
|
||||
Modal,
|
||||
Banner,
|
||||
Card,
|
||||
Table,
|
||||
Tag,
|
||||
Popconfirm,
|
||||
Space,
|
||||
Select,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons';
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
// Preset templates for common OAuth providers
|
||||
const OAUTH_PRESETS = {
|
||||
'github-enterprise': {
|
||||
name: 'GitHub Enterprise',
|
||||
authorization_endpoint: '/login/oauth/authorize',
|
||||
token_endpoint: '/login/oauth/access_token',
|
||||
user_info_endpoint: '/api/v3/user',
|
||||
scopes: 'user:email',
|
||||
user_id_field: 'id',
|
||||
username_field: 'login',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
},
|
||||
gitlab: {
|
||||
name: 'GitLab',
|
||||
authorization_endpoint: '/oauth/authorize',
|
||||
token_endpoint: '/oauth/token',
|
||||
user_info_endpoint: '/api/v4/user',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'id',
|
||||
username_field: 'username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
},
|
||||
gitea: {
|
||||
name: 'Gitea',
|
||||
authorization_endpoint: '/login/oauth/authorize',
|
||||
token_endpoint: '/login/oauth/access_token',
|
||||
user_info_endpoint: '/api/v1/user',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'id',
|
||||
username_field: 'login',
|
||||
display_name_field: 'full_name',
|
||||
email_field: 'email',
|
||||
},
|
||||
nextcloud: {
|
||||
name: 'Nextcloud',
|
||||
authorization_endpoint: '/apps/oauth2/authorize',
|
||||
token_endpoint: '/apps/oauth2/api/v1/token',
|
||||
user_info_endpoint: '/ocs/v2.php/cloud/user?format=json',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'ocs.data.id',
|
||||
username_field: 'ocs.data.id',
|
||||
display_name_field: 'ocs.data.displayname',
|
||||
email_field: 'ocs.data.email',
|
||||
},
|
||||
keycloak: {
|
||||
name: 'Keycloak',
|
||||
authorization_endpoint: '/realms/{realm}/protocol/openid-connect/auth',
|
||||
token_endpoint: '/realms/{realm}/protocol/openid-connect/token',
|
||||
user_info_endpoint: '/realms/{realm}/protocol/openid-connect/userinfo',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'sub',
|
||||
username_field: 'preferred_username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
},
|
||||
authentik: {
|
||||
name: 'Authentik',
|
||||
authorization_endpoint: '/application/o/authorize/',
|
||||
token_endpoint: '/application/o/token/',
|
||||
user_info_endpoint: '/application/o/userinfo/',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'sub',
|
||||
username_field: 'preferred_username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
},
|
||||
ory: {
|
||||
name: 'ORY Hydra',
|
||||
authorization_endpoint: '/oauth2/auth',
|
||||
token_endpoint: '/oauth2/token',
|
||||
user_info_endpoint: '/userinfo',
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'sub',
|
||||
username_field: 'preferred_username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
},
|
||||
};
|
||||
|
||||
const CustomOAuthSetting = ({ serverAddress }) => {
|
||||
const { t } = useTranslation();
|
||||
const [providers, setProviders] = useState([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [modalVisible, setModalVisible] = useState(false);
|
||||
const [editingProvider, setEditingProvider] = useState(null);
|
||||
const [formValues, setFormValues] = useState({});
|
||||
const [selectedPreset, setSelectedPreset] = useState('');
|
||||
const [baseUrl, setBaseUrl] = useState('');
|
||||
const formApiRef = React.useRef(null);
|
||||
|
||||
const fetchProviders = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/custom-oauth-provider/');
|
||||
if (res.data.success) {
|
||||
setProviders(res.data.data || []);
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('获取自定义 OAuth 提供商列表失败'));
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchProviders();
|
||||
}, []);
|
||||
|
||||
const handleAdd = () => {
|
||||
setEditingProvider(null);
|
||||
setFormValues({
|
||||
enabled: false,
|
||||
scopes: 'openid profile email',
|
||||
user_id_field: 'sub',
|
||||
username_field: 'preferred_username',
|
||||
display_name_field: 'name',
|
||||
email_field: 'email',
|
||||
auth_style: 0,
|
||||
});
|
||||
setSelectedPreset('');
|
||||
setBaseUrl('');
|
||||
setModalVisible(true);
|
||||
};
|
||||
|
||||
const handleEdit = (provider) => {
|
||||
setEditingProvider(provider);
|
||||
setFormValues({ ...provider });
|
||||
setSelectedPreset('');
|
||||
setBaseUrl('');
|
||||
setModalVisible(true);
|
||||
};
|
||||
|
||||
const handleDelete = async (id) => {
|
||||
try {
|
||||
const res = await API.delete(`/api/custom-oauth-provider/${id}`);
|
||||
if (res.data.success) {
|
||||
showSuccess(t('删除成功'));
|
||||
fetchProviders();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('删除失败'));
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async () => {
|
||||
// Validate required fields
|
||||
const requiredFields = [
|
||||
'name',
|
||||
'slug',
|
||||
'client_id',
|
||||
'authorization_endpoint',
|
||||
'token_endpoint',
|
||||
'user_info_endpoint',
|
||||
];
|
||||
|
||||
if (!editingProvider) {
|
||||
requiredFields.push('client_secret');
|
||||
}
|
||||
|
||||
for (const field of requiredFields) {
|
||||
if (!formValues[field]) {
|
||||
showError(t(`请填写 ${field}`));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate endpoint URLs must be full URLs
|
||||
const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint'];
|
||||
for (const field of endpointFields) {
|
||||
const value = formValues[field];
|
||||
if (value && !value.startsWith('http://') && !value.startsWith('https://')) {
|
||||
// Check if user selected a preset but forgot to fill server address
|
||||
if (selectedPreset && !baseUrl) {
|
||||
showError(t('请先填写服务器地址,以自动生成完整的端点 URL'));
|
||||
} else {
|
||||
showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)'));
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
let res;
|
||||
if (editingProvider) {
|
||||
res = await API.put(
|
||||
`/api/custom-oauth-provider/${editingProvider.id}`,
|
||||
formValues
|
||||
);
|
||||
} else {
|
||||
res = await API.post('/api/custom-oauth-provider/', formValues);
|
||||
}
|
||||
|
||||
if (res.data.success) {
|
||||
showSuccess(editingProvider ? t('更新成功') : t('创建成功'));
|
||||
setModalVisible(false);
|
||||
fetchProviders();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(editingProvider ? t('更新失败') : t('创建失败'));
|
||||
}
|
||||
};
|
||||
|
||||
const handlePresetChange = (preset) => {
|
||||
setSelectedPreset(preset);
|
||||
if (preset && OAUTH_PRESETS[preset]) {
|
||||
const presetConfig = OAUTH_PRESETS[preset];
|
||||
const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : '';
|
||||
const newValues = {
|
||||
name: presetConfig.name,
|
||||
slug: preset,
|
||||
scopes: presetConfig.scopes,
|
||||
user_id_field: presetConfig.user_id_field,
|
||||
username_field: presetConfig.username_field,
|
||||
display_name_field: presetConfig.display_name_field,
|
||||
email_field: presetConfig.email_field,
|
||||
auth_style: presetConfig.auth_style ?? 0,
|
||||
};
|
||||
// Only fill endpoints if server address is provided
|
||||
if (cleanUrl) {
|
||||
newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint;
|
||||
newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint;
|
||||
newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint;
|
||||
}
|
||||
setFormValues((prev) => ({ ...prev, ...newValues }));
|
||||
// Update form fields directly via formApi
|
||||
if (formApiRef.current) {
|
||||
Object.entries(newValues).forEach(([key, value]) => {
|
||||
formApiRef.current.setValue(key, value);
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleBaseUrlChange = (url) => {
|
||||
setBaseUrl(url);
|
||||
if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) {
|
||||
const presetConfig = OAUTH_PRESETS[selectedPreset];
|
||||
const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes
|
||||
const newValues = {
|
||||
authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint,
|
||||
token_endpoint: cleanUrl + presetConfig.token_endpoint,
|
||||
user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint,
|
||||
};
|
||||
setFormValues((prev) => ({ ...prev, ...newValues }));
|
||||
// Update form fields directly via formApi (use merge mode to preserve other fields)
|
||||
if (formApiRef.current) {
|
||||
Object.entries(newValues).forEach(([key, value]) => {
|
||||
formApiRef.current.setValue(key, value);
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: t('名称'),
|
||||
dataIndex: 'name',
|
||||
key: 'name',
|
||||
},
|
||||
{
|
||||
title: 'Slug',
|
||||
dataIndex: 'slug',
|
||||
key: 'slug',
|
||||
render: (slug) => <Tag>{slug}</Tag>,
|
||||
},
|
||||
{
|
||||
title: t('状态'),
|
||||
dataIndex: 'enabled',
|
||||
key: 'enabled',
|
||||
render: (enabled) => (
|
||||
<Tag color={enabled ? 'green' : 'grey'}>
|
||||
{enabled ? t('已启用') : t('已禁用')}
|
||||
</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('Client ID'),
|
||||
dataIndex: 'client_id',
|
||||
key: 'client_id',
|
||||
render: (id) => (id ? id.substring(0, 20) + '...' : '-'),
|
||||
},
|
||||
{
|
||||
title: t('操作'),
|
||||
key: 'actions',
|
||||
render: (_, record) => (
|
||||
<Space>
|
||||
<Button
|
||||
icon={<IconEdit />}
|
||||
size="small"
|
||||
onClick={() => handleEdit(record)}
|
||||
>
|
||||
{t('编辑')}
|
||||
</Button>
|
||||
<Popconfirm
|
||||
title={t('确定要删除此 OAuth 提供商吗?')}
|
||||
onConfirm={() => handleDelete(record.id)}
|
||||
>
|
||||
<Button icon={<IconDelete />} size="small" type="danger">
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<Form.Section text={t('自定义 OAuth 提供商')}>
|
||||
<Banner
|
||||
type="info"
|
||||
description={
|
||||
<>
|
||||
{t(
|
||||
'配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商'
|
||||
)}
|
||||
<br />
|
||||
{t('回调 URL 格式')}: {serverAddress || t('网站地址')}/oauth/
|
||||
{'{slug}'}
|
||||
</>
|
||||
}
|
||||
style={{ marginBottom: 20 }}
|
||||
/>
|
||||
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
theme="solid"
|
||||
onClick={handleAdd}
|
||||
style={{ marginBottom: 16 }}
|
||||
>
|
||||
{t('添加 OAuth 提供商')}
|
||||
</Button>
|
||||
|
||||
<Table
|
||||
columns={columns}
|
||||
dataSource={providers}
|
||||
loading={loading}
|
||||
rowKey="id"
|
||||
pagination={false}
|
||||
empty={t('暂无自定义 OAuth 提供商')}
|
||||
/>
|
||||
|
||||
<Modal
|
||||
title={editingProvider ? t('编辑 OAuth 提供商') : t('添加 OAuth 提供商')}
|
||||
visible={modalVisible}
|
||||
onOk={handleSubmit}
|
||||
onCancel={() => setModalVisible(false)}
|
||||
okText={t('保存')}
|
||||
cancelText={t('取消')}
|
||||
width={800}
|
||||
>
|
||||
<Form
|
||||
initValues={formValues}
|
||||
onValueChange={(values) => setFormValues(values)}
|
||||
getFormApi={(api) => (formApiRef.current = api)}
|
||||
>
|
||||
{!editingProvider && (
|
||||
<Row gutter={16} style={{ marginBottom: 16 }}>
|
||||
<Col span={12}>
|
||||
<Form.Select
|
||||
field="preset"
|
||||
label={t('预设模板')}
|
||||
placeholder={t('选择预设模板(可选)')}
|
||||
value={selectedPreset}
|
||||
onChange={handlePresetChange}
|
||||
optionList={[
|
||||
{ value: '', label: t('自定义') },
|
||||
...Object.entries(OAUTH_PRESETS).map(([key, config]) => ({
|
||||
value: key,
|
||||
label: config.name,
|
||||
})),
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="base_url"
|
||||
label={
|
||||
selectedPreset
|
||||
? t('服务器地址') + ' *'
|
||||
: t('服务器地址')
|
||||
}
|
||||
placeholder={t('例如:https://gitea.example.com')}
|
||||
value={baseUrl}
|
||||
onChange={handleBaseUrlChange}
|
||||
extraText={
|
||||
selectedPreset
|
||||
? t('必填:请输入服务器地址以自动生成完整端点 URL')
|
||||
: t('选择预设模板后填写服务器地址可自动填充端点')
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
)}
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="name"
|
||||
label={t('显示名称')}
|
||||
placeholder={t('例如:GitHub Enterprise')}
|
||||
rules={[{ required: true, message: t('请输入显示名称') }]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="slug"
|
||||
label="Slug"
|
||||
placeholder={t('例如:github-enterprise')}
|
||||
extraText={t('URL 标识,只能包含小写字母、数字和连字符')}
|
||||
rules={[{ required: true, message: t('请输入 Slug') }]}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="client_id"
|
||||
label="Client ID"
|
||||
placeholder={t('OAuth Client ID')}
|
||||
rules={[{ required: true, message: t('请输入 Client ID') }]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="client_secret"
|
||||
label="Client Secret"
|
||||
type="password"
|
||||
placeholder={
|
||||
editingProvider
|
||||
? t('留空则保持原有密钥')
|
||||
: t('OAuth Client Secret')
|
||||
}
|
||||
rules={
|
||||
editingProvider
|
||||
? []
|
||||
: [{ required: true, message: t('请输入 Client Secret') }]
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
|
||||
{t('OAuth 端点')}
|
||||
</Text>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={24}>
|
||||
<Form.Input
|
||||
field="authorization_endpoint"
|
||||
label={t('Authorization Endpoint')}
|
||||
placeholder={
|
||||
selectedPreset && OAUTH_PRESETS[selectedPreset]
|
||||
? t('填写服务器地址后自动生成:') +
|
||||
OAUTH_PRESETS[selectedPreset].authorization_endpoint
|
||||
: 'https://example.com/oauth/authorize'
|
||||
}
|
||||
rules={[
|
||||
{ required: true, message: t('请输入 Authorization Endpoint') },
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="token_endpoint"
|
||||
label={t('Token Endpoint')}
|
||||
placeholder={
|
||||
selectedPreset && OAUTH_PRESETS[selectedPreset]
|
||||
? t('自动生成:') + OAUTH_PRESETS[selectedPreset].token_endpoint
|
||||
: 'https://example.com/oauth/token'
|
||||
}
|
||||
rules={[{ required: true, message: t('请输入 Token Endpoint') }]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="user_info_endpoint"
|
||||
label={t('User Info Endpoint')}
|
||||
placeholder={
|
||||
selectedPreset && OAUTH_PRESETS[selectedPreset]
|
||||
? t('自动生成:') + OAUTH_PRESETS[selectedPreset].user_info_endpoint
|
||||
: 'https://example.com/api/user'
|
||||
}
|
||||
rules={[
|
||||
{ required: true, message: t('请输入 User Info Endpoint') },
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="scopes"
|
||||
label={t('Scopes')}
|
||||
placeholder="openid profile email"
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="well_known"
|
||||
label={t('Well-Known URL')}
|
||||
placeholder={t('OIDC Discovery 端点(可选)')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
|
||||
{t('字段映射')}
|
||||
</Text>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('配置如何从用户信息 API 响应中提取用户数据,支持 JSONPath 语法')}
|
||||
</Text>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="user_id_field"
|
||||
label={t('用户 ID 字段')}
|
||||
placeholder={t('例如:sub、id、data.user.id')}
|
||||
extraText={t('用于唯一标识用户的字段路径')}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="username_field"
|
||||
label={t('用户名字段')}
|
||||
placeholder={t('例如:preferred_username、login')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="display_name_field"
|
||||
label={t('显示名称字段')}
|
||||
placeholder={t('例如:name、full_name')}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Input
|
||||
field="email_field"
|
||||
label={t('邮箱字段')}
|
||||
placeholder={t('例如:email')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
|
||||
{t('高级选项')}
|
||||
</Text>
|
||||
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Select
|
||||
field="auth_style"
|
||||
label={t('认证方式')}
|
||||
optionList={[
|
||||
{ value: 0, label: t('自动检测') },
|
||||
{ value: 1, label: t('POST 参数') },
|
||||
{ value: 2, label: t('Basic Auth 头') },
|
||||
]}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Checkbox field="enabled" noLabel>
|
||||
{t('启用此 OAuth 提供商')}
|
||||
</Form.Checkbox>
|
||||
</Col>
|
||||
</Row>
|
||||
</Form>
|
||||
</Modal>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default CustomOAuthSetting;
|
||||
@@ -78,6 +78,9 @@ const OperationSetting = () => {
|
||||
'checkin_setting.enabled': false,
|
||||
'checkin_setting.min_quota': 1000,
|
||||
'checkin_setting.max_quota': 10000,
|
||||
|
||||
/* 令牌设置 */
|
||||
'token_setting.max_user_tokens': 1000,
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
@@ -42,6 +42,7 @@ import {
|
||||
} from '../../helpers';
|
||||
import axios from 'axios';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CustomOAuthSetting from './CustomOAuthSetting';
|
||||
|
||||
const SystemSetting = () => {
|
||||
const { t } = useTranslation();
|
||||
@@ -1534,6 +1535,8 @@ const SystemSetting = () => {
|
||||
</Form.Section>
|
||||
</Card>
|
||||
|
||||
<CustomOAuthSetting serverAddress={inputs.ServerAddress} />
|
||||
|
||||
<Card>
|
||||
<Form.Section text={t('配置 WeChat Server')}>
|
||||
<Text>{t('用以支持通过微信进行登录注册')}</Text>
|
||||
|
||||
@@ -42,10 +42,14 @@ import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si';
|
||||
import { UserPlus, ShieldCheck } from 'lucide-react';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
import {
|
||||
API,
|
||||
showError,
|
||||
showSuccess,
|
||||
onGitHubOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
onCustomOAuthClicked,
|
||||
} from '../../../../helpers';
|
||||
import TwoFASetting from '../components/TwoFASetting';
|
||||
|
||||
@@ -94,6 +98,66 @@ const AccountManagement = ({
|
||||
const isBound = (accountId) => Boolean(accountId);
|
||||
const [showTelegramBindModal, setShowTelegramBindModal] =
|
||||
React.useState(false);
|
||||
const [customOAuthBindings, setCustomOAuthBindings] = React.useState([]);
|
||||
const [customOAuthLoading, setCustomOAuthLoading] = React.useState({});
|
||||
|
||||
// Fetch custom OAuth bindings
|
||||
const loadCustomOAuthBindings = async () => {
|
||||
try {
|
||||
const res = await API.get('/api/user/oauth/bindings');
|
||||
if (res.data.success) {
|
||||
setCustomOAuthBindings(res.data.data || []);
|
||||
}
|
||||
} catch (error) {
|
||||
// ignore
|
||||
}
|
||||
};
|
||||
|
||||
// Unbind custom OAuth provider
|
||||
const handleUnbindCustomOAuth = async (providerId, providerName) => {
|
||||
Modal.confirm({
|
||||
title: t('确认解绑'),
|
||||
content: t('确定要解绑 {{name}} 吗?', { name: providerName }),
|
||||
okText: t('确认'),
|
||||
cancelText: t('取消'),
|
||||
onOk: async () => {
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: true }));
|
||||
try {
|
||||
const res = await API.delete(`/api/user/oauth/bindings/${providerId}`);
|
||||
if (res.data.success) {
|
||||
showSuccess(t('解绑成功'));
|
||||
await loadCustomOAuthBindings();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('操作失败'));
|
||||
} finally {
|
||||
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
// Handle bind custom OAuth
|
||||
const handleBindCustomOAuth = (provider) => {
|
||||
onCustomOAuthClicked(provider);
|
||||
};
|
||||
|
||||
// Check if custom OAuth provider is bound
|
||||
const isCustomOAuthBound = (providerId) => {
|
||||
return customOAuthBindings.some((b) => b.provider_id === providerId);
|
||||
};
|
||||
|
||||
// Get binding info for a provider
|
||||
const getCustomOAuthBinding = (providerId) => {
|
||||
return customOAuthBindings.find((b) => b.provider_id === providerId);
|
||||
};
|
||||
|
||||
React.useEffect(() => {
|
||||
loadCustomOAuthBindings();
|
||||
}, []);
|
||||
|
||||
const passkeyEnabled = passkeyStatus?.enabled;
|
||||
const lastUsedLabel = passkeyStatus?.last_used_at
|
||||
? new Date(passkeyStatus.last_used_at).toLocaleString()
|
||||
@@ -447,6 +511,64 @@ const AccountManagement = ({
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* 自定义 OAuth 提供商绑定 */}
|
||||
{status.custom_oauth_providers &&
|
||||
status.custom_oauth_providers.map((provider) => {
|
||||
const bound = isCustomOAuthBound(provider.id);
|
||||
const binding = getCustomOAuthBinding(provider.id);
|
||||
return (
|
||||
<Card key={provider.slug} className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
<div className='flex items-center flex-1 min-w-0'>
|
||||
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
|
||||
<IconLock
|
||||
size='default'
|
||||
className='text-slate-600 dark:text-slate-300'
|
||||
/>
|
||||
</div>
|
||||
<div className='flex-1 min-w-0'>
|
||||
<div className='font-medium text-gray-900'>
|
||||
{provider.name}
|
||||
</div>
|
||||
<div className='text-sm text-gray-500 truncate'>
|
||||
{bound
|
||||
? renderAccountInfo(
|
||||
binding?.provider_user_id,
|
||||
t('{{name}} ID', { name: provider.name }),
|
||||
)
|
||||
: t('未绑定')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className='flex-shrink-0'>
|
||||
{bound ? (
|
||||
<Button
|
||||
type='danger'
|
||||
theme='outline'
|
||||
size='small'
|
||||
loading={customOAuthLoading[provider.id]}
|
||||
onClick={() =>
|
||||
handleUnbindCustomOAuth(provider.id, provider.name)
|
||||
}
|
||||
>
|
||||
{t('解绑')}
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
type='primary'
|
||||
theme='outline'
|
||||
size='small'
|
||||
onClick={() => handleBindCustomOAuth(provider)}
|
||||
>
|
||||
{t('绑定')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</TabPane>
|
||||
|
||||
@@ -1,280 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import {
|
||||
Empty,
|
||||
Skeleton,
|
||||
Space,
|
||||
Tag,
|
||||
Collapsible,
|
||||
Tabs,
|
||||
TabPane,
|
||||
Typography,
|
||||
Avatar,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoContent,
|
||||
IllustrationNoContentDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons';
|
||||
import { Settings } from 'lucide-react';
|
||||
import { renderModelTag, getModelCategories } from '../../../../helpers';
|
||||
|
||||
const ModelsList = ({ t, models, modelsLoading, copyText }) => {
|
||||
const [isModelsExpanded, setIsModelsExpanded] = useState(() => {
|
||||
// Initialize from localStorage if available
|
||||
const savedState = localStorage.getItem('modelsExpanded');
|
||||
return savedState ? JSON.parse(savedState) : false;
|
||||
});
|
||||
const [activeModelCategory, setActiveModelCategory] = useState('all');
|
||||
const MODELS_DISPLAY_COUNT = 25; // 默认显示的模型数量
|
||||
|
||||
// Save models expanded state to localStorage whenever it changes
|
||||
useEffect(() => {
|
||||
localStorage.setItem('modelsExpanded', JSON.stringify(isModelsExpanded));
|
||||
}, [isModelsExpanded]);
|
||||
|
||||
return (
|
||||
<div className='py-4'>
|
||||
{/* 卡片头部 */}
|
||||
<div className='flex items-center mb-4'>
|
||||
<Avatar size='small' color='green' className='mr-3 shadow-md'>
|
||||
<Settings size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Typography.Text className='text-lg font-medium'>
|
||||
{t('可用模型')}
|
||||
</Typography.Text>
|
||||
<div className='text-xs text-gray-600'>
|
||||
{t('查看当前可用的所有模型')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 可用模型部分 */}
|
||||
<div className='bg-gray-50 dark:bg-gray-800 rounded-xl'>
|
||||
{modelsLoading ? (
|
||||
// 骨架屏加载状态 - 模拟实际加载后的布局
|
||||
<div className='space-y-4'>
|
||||
{/* 模拟分类标签 */}
|
||||
<div
|
||||
className='mb-4'
|
||||
style={{ borderBottom: '1px solid var(--semi-color-border)' }}
|
||||
>
|
||||
<div className='flex overflow-x-auto py-2 gap-2'>
|
||||
{Array.from({ length: 8 }).map((_, index) => (
|
||||
<Skeleton.Button
|
||||
key={`cat-${index}`}
|
||||
style={{
|
||||
width: index === 0 ? 130 : 100 + Math.random() * 50,
|
||||
height: 36,
|
||||
borderRadius: 8,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 模拟模型标签列表 */}
|
||||
<div className='flex flex-wrap gap-2'>
|
||||
{Array.from({ length: 20 }).map((_, index) => (
|
||||
<Skeleton.Button
|
||||
key={`model-${index}`}
|
||||
style={{
|
||||
width: 100 + Math.random() * 100,
|
||||
height: 32,
|
||||
borderRadius: 16,
|
||||
margin: '4px',
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : models.length === 0 ? (
|
||||
<div className='py-8'>
|
||||
<Empty
|
||||
image={
|
||||
<IllustrationNoContent style={{ width: 150, height: 150 }} />
|
||||
}
|
||||
darkModeImage={
|
||||
<IllustrationNoContentDark
|
||||
style={{ width: 150, height: 150 }}
|
||||
/>
|
||||
}
|
||||
description={t('没有可用模型')}
|
||||
style={{ padding: '24px 0' }}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{/* 模型分类标签页 */}
|
||||
<div className='mb-4'>
|
||||
<Tabs
|
||||
type='card'
|
||||
activeKey={activeModelCategory}
|
||||
onChange={(key) => setActiveModelCategory(key)}
|
||||
className='mt-2'
|
||||
collapsible
|
||||
>
|
||||
{Object.entries(getModelCategories(t)).map(
|
||||
([key, category]) => {
|
||||
// 计算该分类下的模型数量
|
||||
const modelCount =
|
||||
key === 'all'
|
||||
? models.length
|
||||
: models.filter((model) =>
|
||||
category.filter({ model_name: model }),
|
||||
).length;
|
||||
|
||||
if (modelCount === 0 && key !== 'all') return null;
|
||||
|
||||
return (
|
||||
<TabPane
|
||||
tab={
|
||||
<span className='flex items-center gap-2'>
|
||||
{category.icon && (
|
||||
<span className='w-4 h-4'>{category.icon}</span>
|
||||
)}
|
||||
{category.label}
|
||||
<Tag
|
||||
color={
|
||||
activeModelCategory === key ? 'red' : 'grey'
|
||||
}
|
||||
size='small'
|
||||
shape='circle'
|
||||
>
|
||||
{modelCount}
|
||||
</Tag>
|
||||
</span>
|
||||
}
|
||||
itemKey={key}
|
||||
key={key}
|
||||
/>
|
||||
);
|
||||
},
|
||||
)}
|
||||
</Tabs>
|
||||
</div>
|
||||
|
||||
<div className='bg-white dark:bg-gray-700 rounded-lg p-3'>
|
||||
{(() => {
|
||||
// 根据当前选中的分类过滤模型
|
||||
const categories = getModelCategories(t);
|
||||
const filteredModels =
|
||||
activeModelCategory === 'all'
|
||||
? models
|
||||
: models.filter((model) =>
|
||||
categories[activeModelCategory].filter({
|
||||
model_name: model,
|
||||
}),
|
||||
);
|
||||
|
||||
// 如果过滤后没有模型,显示空状态
|
||||
if (filteredModels.length === 0) {
|
||||
return (
|
||||
<Empty
|
||||
image={
|
||||
<IllustrationNoContent
|
||||
style={{ width: 120, height: 120 }}
|
||||
/>
|
||||
}
|
||||
darkModeImage={
|
||||
<IllustrationNoContentDark
|
||||
style={{ width: 120, height: 120 }}
|
||||
/>
|
||||
}
|
||||
description={t('该分类下没有可用模型')}
|
||||
style={{ padding: '16px 0' }}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (filteredModels.length <= MODELS_DISPLAY_COUNT) {
|
||||
return (
|
||||
<Space wrap>
|
||||
{filteredModels.map((model) =>
|
||||
renderModelTag(model, {
|
||||
size: 'small',
|
||||
shape: 'circle',
|
||||
onClick: () => copyText(model),
|
||||
}),
|
||||
)}
|
||||
</Space>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<>
|
||||
<Collapsible isOpen={isModelsExpanded}>
|
||||
<Space wrap>
|
||||
{filteredModels.map((model) =>
|
||||
renderModelTag(model, {
|
||||
size: 'small',
|
||||
shape: 'circle',
|
||||
onClick: () => copyText(model),
|
||||
}),
|
||||
)}
|
||||
<Tag
|
||||
color='grey'
|
||||
type='light'
|
||||
className='cursor-pointer !rounded-lg'
|
||||
onClick={() => setIsModelsExpanded(false)}
|
||||
icon={<IconChevronUp />}
|
||||
>
|
||||
{t('收起')}
|
||||
</Tag>
|
||||
</Space>
|
||||
</Collapsible>
|
||||
{!isModelsExpanded && (
|
||||
<Space wrap>
|
||||
{filteredModels
|
||||
.slice(0, MODELS_DISPLAY_COUNT)
|
||||
.map((model) =>
|
||||
renderModelTag(model, {
|
||||
size: 'small',
|
||||
shape: 'circle',
|
||||
onClick: () => copyText(model),
|
||||
}),
|
||||
)}
|
||||
<Tag
|
||||
color='grey'
|
||||
type='light'
|
||||
className='cursor-pointer !rounded-lg'
|
||||
onClick={() => setIsModelsExpanded(true)}
|
||||
icon={<IconChevronDown />}
|
||||
>
|
||||
{t('更多')}{' '}
|
||||
{filteredModels.length - MODELS_DISPLAY_COUNT}{' '}
|
||||
{t('个模型')}
|
||||
</Tag>
|
||||
</Space>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
})()}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ModelsList;
|
||||
@@ -1,44 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Typography } from '@douyinfe/semi-ui';
|
||||
import { Layers } from 'lucide-react';
|
||||
import CompactModeToggle from '../../common/ui/CompactModeToggle';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const ModelsDescription = ({ compactMode, setCompactMode, t }) => {
|
||||
return (
|
||||
<div className='flex flex-col md:flex-row justify-between items-start md:items-center gap-2 w-full'>
|
||||
<div className='flex items-center text-green-500'>
|
||||
<Layers size={16} className='mr-2' />
|
||||
<Text>{t('模型管理')}</Text>
|
||||
</div>
|
||||
|
||||
<CompactModeToggle
|
||||
compactMode={compactMode}
|
||||
setCompactMode={setCompactMode}
|
||||
t={t}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ModelsDescription;
|
||||
@@ -1,123 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
import { Modal, Select, Space, Typography } from '@douyinfe/semi-ui';
|
||||
import { API, showError, showSuccess } from '../../../../helpers';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const BindSubscriptionModal = ({ visible, onCancel, user, t, onSuccess }) => {
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [plans, setPlans] = useState([]);
|
||||
const [selectedPlanId, setSelectedPlanId] = useState(null);
|
||||
|
||||
const loadPlans = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/subscription/admin/plans');
|
||||
if (res.data?.success) {
|
||||
setPlans(res.data.data || []);
|
||||
} else {
|
||||
showError(res.data?.message || t('加载失败'));
|
||||
}
|
||||
} catch (e) {
|
||||
showError(t('请求失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
setSelectedPlanId(null);
|
||||
loadPlans();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
const planOptions = useMemo(() => {
|
||||
return (plans || []).map((p) => ({
|
||||
label: `${p?.plan?.title || ''} (${p?.plan?.currency || 'USD'} ${Number(p?.plan?.price_amount || 0)})`,
|
||||
value: p?.plan?.id,
|
||||
}));
|
||||
}, [plans]);
|
||||
|
||||
const bind = async () => {
|
||||
if (!user?.id) {
|
||||
showError(t('用户信息缺失'));
|
||||
return;
|
||||
}
|
||||
if (!selectedPlanId) {
|
||||
showError(t('请选择订阅套餐'));
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/subscription/admin/bind', {
|
||||
user_id: user.id,
|
||||
plan_id: selectedPlanId,
|
||||
});
|
||||
if (res.data?.success) {
|
||||
showSuccess(t('绑定成功'));
|
||||
onSuccess?.();
|
||||
onCancel?.();
|
||||
} else {
|
||||
showError(res.data?.message || t('绑定失败'));
|
||||
}
|
||||
} catch (e) {
|
||||
showError(t('请求失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('绑定订阅套餐')}
|
||||
visible={visible}
|
||||
onCancel={onCancel}
|
||||
onOk={bind}
|
||||
confirmLoading={loading}
|
||||
maskClosable={false}
|
||||
centered
|
||||
>
|
||||
<Space vertical style={{ width: '100%' }} spacing='medium'>
|
||||
<div className='text-sm'>
|
||||
<Text strong>{t('用户')}:</Text>
|
||||
<Text>{user?.username}</Text>
|
||||
<Text type='tertiary'> (ID: {user?.id})</Text>
|
||||
</div>
|
||||
<Select
|
||||
placeholder={t('选择订阅套餐')}
|
||||
optionList={planOptions}
|
||||
value={selectedPlanId}
|
||||
onChange={setSelectedPlanId}
|
||||
loading={loading}
|
||||
filter
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
<div className='text-xs text-gray-500'>
|
||||
{t('绑定后会立即生成用户订阅(无需支付),有效期按套餐配置计算。')}
|
||||
</div>
|
||||
</Space>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default BindSubscriptionModal;
|
||||
@@ -294,6 +294,48 @@ export async function onLinuxDOOAuthClicked(
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiate custom OAuth login
|
||||
* @param {Object} provider - Custom OAuth provider config from status API
|
||||
* @param {string} provider.slug - Provider slug (used for callback URL)
|
||||
* @param {string} provider.client_id - OAuth client ID
|
||||
* @param {string} provider.authorization_endpoint - Authorization URL
|
||||
* @param {string} provider.scopes - OAuth scopes (space-separated)
|
||||
* @param {Object} options - Options
|
||||
* @param {boolean} options.shouldLogout - Whether to logout first
|
||||
*/
|
||||
export async function onCustomOAuthClicked(provider, options = {}) {
|
||||
const state = await prepareOAuthState(options);
|
||||
if (!state) return;
|
||||
|
||||
try {
|
||||
const redirect_uri = `${window.location.origin}/oauth/${provider.slug}`;
|
||||
|
||||
// Check if authorization_endpoint is a full URL or relative path
|
||||
let authUrl;
|
||||
if (provider.authorization_endpoint.startsWith('http://') ||
|
||||
provider.authorization_endpoint.startsWith('https://')) {
|
||||
authUrl = new URL(provider.authorization_endpoint);
|
||||
} else {
|
||||
// Relative path - this is a configuration error, show error message
|
||||
console.error('Custom OAuth authorization_endpoint must be a full URL:', provider.authorization_endpoint);
|
||||
showError('OAuth 配置错误:授权端点必须是完整的 URL(以 http:// 或 https:// 开头)');
|
||||
return;
|
||||
}
|
||||
|
||||
authUrl.searchParams.set('client_id', provider.client_id);
|
||||
authUrl.searchParams.set('redirect_uri', redirect_uri);
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
authUrl.searchParams.set('scope', provider.scopes || 'openid profile email');
|
||||
authUrl.searchParams.set('state', state);
|
||||
|
||||
window.open(authUrl.toString());
|
||||
} catch (error) {
|
||||
console.error('Failed to initiate custom OAuth:', error);
|
||||
showError('OAuth 登录失败:' + (error.message || '未知错误'));
|
||||
}
|
||||
}
|
||||
|
||||
let channelModels = undefined;
|
||||
export async function loadChannelModels() {
|
||||
const res = await API.get('/api/models');
|
||||
|
||||
@@ -571,7 +571,6 @@ export const modelColorMap = {
|
||||
'claude-3-opus-20240229': 'rgb(255,132,31)', // 橙红色
|
||||
'claude-3-sonnet-20240229': 'rgb(253,135,93)', // 橙色
|
||||
'claude-3-haiku-20240307': 'rgb(255,175,146)', // 浅橙色
|
||||
'claude-2.1': 'rgb(255,209,190)', // 浅橙色(略有区别)
|
||||
};
|
||||
|
||||
export function modelToColor(modelName) {
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import { useState, useCallback } from 'react';
|
||||
import { API } from '../../helpers';
|
||||
import { showError } from '../../helpers';
|
||||
|
||||
export const useDeploymentResources = () => {
|
||||
const [hardwareTypes, setHardwareTypes] = useState([]);
|
||||
const [hardwareTotalAvailable, setHardwareTotalAvailable] = useState(0);
|
||||
const [locations, setLocations] = useState([]);
|
||||
const [locationsTotalAvailable, setLocationsTotalAvailable] = useState(0);
|
||||
const [availableReplicas, setAvailableReplicas] = useState([]);
|
||||
const [priceEstimation, setPriceEstimation] = useState(null);
|
||||
|
||||
const [loadingHardware, setLoadingHardware] = useState(false);
|
||||
const [loadingLocations, setLoadingLocations] = useState(false);
|
||||
const [loadingReplicas, setLoadingReplicas] = useState(false);
|
||||
const [loadingPrice, setLoadingPrice] = useState(false);
|
||||
|
||||
const fetchHardwareTypes = useCallback(async () => {
|
||||
try {
|
||||
setLoadingHardware(true);
|
||||
const response = await API.get('/api/deployments/hardware-types');
|
||||
if (response.data.success) {
|
||||
const { hardware_types: hardwareList = [], total_available } =
|
||||
response.data.data || {};
|
||||
const normalizedHardware = hardwareList.map((hardware) => {
|
||||
const availableCountValue = Number(hardware.available_count);
|
||||
const availableCount = Number.isNaN(availableCountValue)
|
||||
? 0
|
||||
: availableCountValue;
|
||||
const availableBool =
|
||||
typeof hardware.available === 'boolean'
|
||||
? hardware.available
|
||||
: availableCount > 0;
|
||||
|
||||
return {
|
||||
...hardware,
|
||||
available: availableBool,
|
||||
available_count: availableCount,
|
||||
};
|
||||
});
|
||||
|
||||
const providedTotal = Number(total_available);
|
||||
const fallbackTotal = normalizedHardware.reduce(
|
||||
(acc, item) =>
|
||||
acc +
|
||||
(Number.isNaN(item.available_count) ? 0 : item.available_count),
|
||||
0,
|
||||
);
|
||||
const hasProvidedTotal =
|
||||
total_available !== undefined &&
|
||||
total_available !== null &&
|
||||
total_available !== '' &&
|
||||
!Number.isNaN(providedTotal);
|
||||
|
||||
setHardwareTypes(normalizedHardware);
|
||||
setHardwareTotalAvailable(
|
||||
hasProvidedTotal ? providedTotal : fallbackTotal,
|
||||
);
|
||||
return normalizedHardware;
|
||||
} else {
|
||||
showError('获取硬件类型失败: ' + response.data.message);
|
||||
setHardwareTotalAvailable(0);
|
||||
return [];
|
||||
}
|
||||
} catch (error) {
|
||||
showError('获取硬件类型失败: ' + error.message);
|
||||
setHardwareTotalAvailable(0);
|
||||
return [];
|
||||
} finally {
|
||||
setLoadingHardware(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const fetchLocations = useCallback(async (hardwareId, gpuCount = 1) => {
|
||||
if (!hardwareId) {
|
||||
setLocations([]);
|
||||
setLocationsTotalAvailable(0);
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
setLoadingLocations(true);
|
||||
const response = await API.get(
|
||||
`/api/deployments/available-replicas?hardware_id=${hardwareId}&gpu_count=${gpuCount}`,
|
||||
);
|
||||
if (response.data.success) {
|
||||
const replicas = response.data.data?.replicas || [];
|
||||
const nextLocationsMap = new Map();
|
||||
replicas.forEach((replica) => {
|
||||
const rawId = replica?.location_id ?? replica?.location?.id;
|
||||
if (rawId === null || rawId === undefined) return;
|
||||
|
||||
const mapKey = String(rawId);
|
||||
if (nextLocationsMap.has(mapKey)) return;
|
||||
|
||||
const rawIso2 =
|
||||
replica?.iso2 ?? replica?.location_iso2 ?? replica?.location?.iso2;
|
||||
const iso2 = rawIso2 ? String(rawIso2).toUpperCase() : '';
|
||||
const name =
|
||||
replica?.location_name ??
|
||||
replica?.location?.name ??
|
||||
replica?.name ??
|
||||
String(rawId);
|
||||
|
||||
nextLocationsMap.set(mapKey, {
|
||||
id: rawId,
|
||||
name: String(name),
|
||||
iso2,
|
||||
region:
|
||||
replica?.region ??
|
||||
replica?.location_region ??
|
||||
replica?.location?.region,
|
||||
country:
|
||||
replica?.country ??
|
||||
replica?.location_country ??
|
||||
replica?.location?.country,
|
||||
code:
|
||||
replica?.code ??
|
||||
replica?.location_code ??
|
||||
replica?.location?.code,
|
||||
available: Number(replica?.available_count) || 0,
|
||||
});
|
||||
});
|
||||
|
||||
const normalizedLocations = Array.from(nextLocationsMap.values());
|
||||
setLocations(normalizedLocations);
|
||||
setLocationsTotalAvailable(
|
||||
normalizedLocations.reduce(
|
||||
(acc, item) => acc + (item.available || 0),
|
||||
0,
|
||||
),
|
||||
);
|
||||
return normalizedLocations;
|
||||
} else {
|
||||
showError('获取部署位置失败: ' + response.data.message);
|
||||
setLocationsTotalAvailable(0);
|
||||
return [];
|
||||
}
|
||||
} catch (error) {
|
||||
showError('获取部署位置失败: ' + error.message);
|
||||
setLocationsTotalAvailable(0);
|
||||
return [];
|
||||
} finally {
|
||||
setLoadingLocations(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const fetchAvailableReplicas = useCallback(
|
||||
async (hardwareId, gpuCount = 1) => {
|
||||
if (!hardwareId) {
|
||||
setAvailableReplicas([]);
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
setLoadingReplicas(true);
|
||||
const response = await API.get(
|
||||
`/api/deployments/available-replicas?hardware_id=${hardwareId}&gpu_count=${gpuCount}`,
|
||||
);
|
||||
if (response.data.success) {
|
||||
const replicas = response.data.data.replicas || [];
|
||||
setAvailableReplicas(replicas);
|
||||
return replicas;
|
||||
} else {
|
||||
showError('获取可用资源失败: ' + response.data.message);
|
||||
setAvailableReplicas([]);
|
||||
return [];
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Load available replicas error:', error);
|
||||
setAvailableReplicas([]);
|
||||
return [];
|
||||
} finally {
|
||||
setLoadingReplicas(false);
|
||||
}
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const calculatePrice = useCallback(async (params) => {
|
||||
const {
|
||||
locationIds,
|
||||
hardwareId,
|
||||
gpusPerContainer,
|
||||
durationHours,
|
||||
replicaCount,
|
||||
} = params;
|
||||
|
||||
if (
|
||||
!locationIds?.length ||
|
||||
!hardwareId ||
|
||||
!gpusPerContainer ||
|
||||
!durationHours ||
|
||||
!replicaCount
|
||||
) {
|
||||
setPriceEstimation(null);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
setLoadingPrice(true);
|
||||
const requestData = {
|
||||
location_ids: locationIds,
|
||||
hardware_id: hardwareId,
|
||||
gpus_per_container: gpusPerContainer,
|
||||
duration_hours: durationHours,
|
||||
replica_count: replicaCount,
|
||||
};
|
||||
|
||||
const response = await API.post(
|
||||
'/api/deployments/price-estimation',
|
||||
requestData,
|
||||
);
|
||||
if (response.data.success) {
|
||||
const estimation = response.data.data;
|
||||
setPriceEstimation(estimation);
|
||||
return estimation;
|
||||
} else {
|
||||
showError('价格计算失败: ' + response.data.message);
|
||||
setPriceEstimation(null);
|
||||
return null;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Price calculation error:', error);
|
||||
setPriceEstimation(null);
|
||||
return null;
|
||||
} finally {
|
||||
setLoadingPrice(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const checkClusterNameAvailability = useCallback(async (name) => {
|
||||
if (!name?.trim()) return false;
|
||||
|
||||
try {
|
||||
const response = await API.get(
|
||||
`/api/deployments/check-name?name=${encodeURIComponent(name.trim())}`,
|
||||
);
|
||||
if (response.data.success) {
|
||||
return response.data.data.available;
|
||||
} else {
|
||||
showError('检查名称可用性失败: ' + response.data.message);
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Check cluster name availability error:', error);
|
||||
return false;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const createDeployment = useCallback(async (deploymentData) => {
|
||||
try {
|
||||
const response = await API.post('/api/deployments', deploymentData);
|
||||
if (response.data.success) {
|
||||
return response.data.data;
|
||||
} else {
|
||||
throw new Error(response.data.message || '创建部署失败');
|
||||
}
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
}, []);
|
||||
|
||||
return {
|
||||
// Data
|
||||
hardwareTypes,
|
||||
hardwareTotalAvailable,
|
||||
locations,
|
||||
locationsTotalAvailable,
|
||||
availableReplicas,
|
||||
priceEstimation,
|
||||
|
||||
// Loading states
|
||||
loadingHardware,
|
||||
loadingLocations,
|
||||
loadingReplicas,
|
||||
loadingPrice,
|
||||
|
||||
// Functions
|
||||
fetchHardwareTypes,
|
||||
fetchLocations,
|
||||
fetchAvailableReplicas,
|
||||
calculatePrice,
|
||||
checkClusterNameAvailability,
|
||||
createDeployment,
|
||||
|
||||
// Clear functions
|
||||
clearPriceEstimation: () => setPriceEstimation(null),
|
||||
clearAvailableReplicas: () => setAvailableReplicas([]),
|
||||
};
|
||||
};
|
||||
|
||||
export default useDeploymentResources;
|
||||
@@ -1,286 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import { useState } from 'react';
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
|
||||
export const useEnhancedDeploymentActions = (t) => {
|
||||
const [loading, setLoading] = useState({});
|
||||
|
||||
// Set loading state for specific operation
|
||||
const setOperationLoading = (operation, deploymentId, isLoading) => {
|
||||
setLoading((prev) => ({
|
||||
...prev,
|
||||
[`${operation}_${deploymentId}`]: isLoading,
|
||||
}));
|
||||
};
|
||||
|
||||
// Get loading state for specific operation
|
||||
const isOperationLoading = (operation, deploymentId) => {
|
||||
return loading[`${operation}_${deploymentId}`] || false;
|
||||
};
|
||||
|
||||
// Extend deployment duration
|
||||
const extendDeployment = async (deploymentId, durationHours) => {
|
||||
try {
|
||||
setOperationLoading('extend', deploymentId, true);
|
||||
|
||||
const response = await API.post(
|
||||
`/api/deployments/${deploymentId}/extend`,
|
||||
{
|
||||
duration_hours: durationHours,
|
||||
},
|
||||
);
|
||||
|
||||
if (response.data.success) {
|
||||
showSuccess(t('容器时长延长成功'));
|
||||
return response.data.data;
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('延长时长失败') +
|
||||
': ' +
|
||||
(error.response?.data?.message || error.message),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('extend', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Get deployment details
|
||||
const getDeploymentDetails = async (deploymentId) => {
|
||||
try {
|
||||
setOperationLoading('details', deploymentId, true);
|
||||
|
||||
const response = await API.get(`/api/deployments/${deploymentId}`);
|
||||
|
||||
if (response.data.success) {
|
||||
return response.data.data;
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('获取详情失败') +
|
||||
': ' +
|
||||
(error.response?.data?.message || error.message),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('details', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Get deployment logs
|
||||
const getDeploymentLogs = async (deploymentId, options = {}) => {
|
||||
try {
|
||||
setOperationLoading('logs', deploymentId, true);
|
||||
|
||||
const params = new URLSearchParams();
|
||||
|
||||
if (options.containerId)
|
||||
params.append('container_id', options.containerId);
|
||||
if (options.level) params.append('level', options.level);
|
||||
if (options.limit) params.append('limit', options.limit.toString());
|
||||
if (options.cursor) params.append('cursor', options.cursor);
|
||||
if (options.follow) params.append('follow', 'true');
|
||||
if (options.startTime) params.append('start_time', options.startTime);
|
||||
if (options.endTime) params.append('end_time', options.endTime);
|
||||
|
||||
const response = await API.get(
|
||||
`/api/deployments/${deploymentId}/logs?${params}`,
|
||||
);
|
||||
|
||||
if (response.data.success) {
|
||||
return response.data.data;
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('获取日志失败') +
|
||||
': ' +
|
||||
(error.response?.data?.message || error.message),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('logs', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Update deployment configuration
|
||||
const updateDeploymentConfig = async (deploymentId, config) => {
|
||||
try {
|
||||
setOperationLoading('config', deploymentId, true);
|
||||
|
||||
const response = await API.put(
|
||||
`/api/deployments/${deploymentId}`,
|
||||
config,
|
||||
);
|
||||
|
||||
if (response.data.success) {
|
||||
showSuccess(t('容器配置更新成功'));
|
||||
return response.data.data;
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('更新配置失败') +
|
||||
': ' +
|
||||
(error.response?.data?.message || error.message),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('config', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Delete (destroy) deployment
|
||||
const deleteDeployment = async (deploymentId) => {
|
||||
try {
|
||||
setOperationLoading('delete', deploymentId, true);
|
||||
|
||||
const response = await API.delete(`/api/deployments/${deploymentId}`);
|
||||
|
||||
if (response.data.success) {
|
||||
showSuccess(t('容器销毁请求已提交'));
|
||||
return response.data.data;
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('销毁容器失败') +
|
||||
': ' +
|
||||
(error.response?.data?.message || error.message),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('delete', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Update deployment name
|
||||
const updateDeploymentName = async (deploymentId, newName) => {
|
||||
try {
|
||||
setOperationLoading('rename', deploymentId, true);
|
||||
|
||||
const response = await API.put(`/api/deployments/${deploymentId}/name`, {
|
||||
name: newName,
|
||||
});
|
||||
|
||||
if (response.data.success) {
|
||||
showSuccess(t('容器名称更新成功'));
|
||||
return response.data.data;
|
||||
}
|
||||
} catch (error) {
|
||||
showError(
|
||||
t('更新名称失败') +
|
||||
': ' +
|
||||
(error.response?.data?.message || error.message),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('rename', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
// Batch operations
|
||||
const batchDelete = async (deploymentIds) => {
|
||||
try {
|
||||
setOperationLoading('batch_delete', 'all', true);
|
||||
|
||||
const results = await Promise.allSettled(
|
||||
deploymentIds.map((id) => deleteDeployment(id)),
|
||||
);
|
||||
|
||||
const successful = results.filter((r) => r.status === 'fulfilled').length;
|
||||
const failed = results.filter((r) => r.status === 'rejected').length;
|
||||
|
||||
if (successful > 0) {
|
||||
showSuccess(
|
||||
t('批量操作完成: {{success}}个成功, {{failed}}个失败', {
|
||||
success: successful,
|
||||
failed: failed,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
return { successful, failed };
|
||||
} catch (error) {
|
||||
showError(t('批量操作失败') + ': ' + error.message);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('batch_delete', 'all', false);
|
||||
}
|
||||
};
|
||||
|
||||
// Export logs
|
||||
const exportLogs = async (deploymentId, options = {}) => {
|
||||
try {
|
||||
setOperationLoading('export_logs', deploymentId, true);
|
||||
|
||||
const logs = await getDeploymentLogs(deploymentId, {
|
||||
...options,
|
||||
limit: 10000, // Get more logs for export
|
||||
});
|
||||
|
||||
if (logs && logs.logs) {
|
||||
const logText = logs.logs
|
||||
.map(
|
||||
(log) =>
|
||||
`[${new Date(log.timestamp).toISOString()}] [${log.level}] ${log.source ? `[${log.source}] ` : ''}${log.message}`,
|
||||
)
|
||||
.join('\n');
|
||||
|
||||
const blob = new Blob([logText], { type: 'text/plain' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `deployment-${deploymentId}-logs-${new Date().toISOString().split('T')[0]}.txt`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
|
||||
showSuccess(t('日志导出成功'));
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('导出日志失败') + ': ' + error.message);
|
||||
throw error;
|
||||
} finally {
|
||||
setOperationLoading('export_logs', deploymentId, false);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
// Actions
|
||||
extendDeployment,
|
||||
getDeploymentDetails,
|
||||
getDeploymentLogs,
|
||||
updateDeploymentConfig,
|
||||
deleteDeployment,
|
||||
updateDeploymentName,
|
||||
batchDelete,
|
||||
exportLogs,
|
||||
|
||||
// Loading states
|
||||
isOperationLoading,
|
||||
loading,
|
||||
|
||||
// Utility
|
||||
setOperationLoading,
|
||||
};
|
||||
};
|
||||
|
||||
export default useEnhancedDeploymentActions;
|
||||
@@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => {
|
||||
const [tokenCount, setTokenCount] = useState(0);
|
||||
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图
|
||||
|
||||
// Selection state
|
||||
const [selectedKeys, setSelectedKeys] = useState([]);
|
||||
@@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => {
|
||||
// Load tokens function
|
||||
const loadTokens = async (page = 1, size = pageSize) => {
|
||||
setLoading(true);
|
||||
setSearchMode(false);
|
||||
const res = await API.get(`/api/token/?p=${page}&size=${size}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
@@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => {
|
||||
};
|
||||
|
||||
// Search tokens function
|
||||
const searchTokens = async () => {
|
||||
const searchTokens = async (page = 1, size = pageSize) => {
|
||||
const { searchKeyword, searchToken } = getFormValues();
|
||||
if (searchKeyword === '' && searchToken === '') {
|
||||
setSearchMode(false);
|
||||
await loadTokens(1);
|
||||
return;
|
||||
}
|
||||
setSearching(true);
|
||||
const res = await API.get(
|
||||
`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`,
|
||||
`/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setTokens(data);
|
||||
setTokenCount(data.length);
|
||||
setActivePage(1);
|
||||
setSearchMode(true);
|
||||
syncPageData(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => {
|
||||
|
||||
// Page handlers
|
||||
const handlePageChange = (page) => {
|
||||
loadTokens(page, pageSize).then();
|
||||
if (searchMode) {
|
||||
searchTokens(page, pageSize).then();
|
||||
} else {
|
||||
loadTokens(page, pageSize).then();
|
||||
}
|
||||
};
|
||||
|
||||
const handlePageSizeChange = async (size) => {
|
||||
setPageSize(size);
|
||||
await loadTokens(1, size);
|
||||
if (searchMode) {
|
||||
await searchTokens(1, size);
|
||||
} else {
|
||||
await loadTokens(1, size);
|
||||
}
|
||||
};
|
||||
|
||||
// Row selection handlers
|
||||
|
||||
@@ -2795,6 +2795,49 @@
|
||||
"语言偏好": "Language Preference",
|
||||
"选择您的首选界面语言,设置将自动保存并同步到所有设备": "Select your preferred interface language. Settings will be saved automatically and synced across all devices",
|
||||
"语言偏好已保存": "Language preference saved",
|
||||
"提示:语言偏好会同步到您登录的所有设备,并影响API返回的错误消息语言。": "Note: Language preference syncs across all your logged-in devices and affects the language of API error messages."
|
||||
"提示:语言偏好会同步到您登录的所有设备,并影响API返回的错误消息语言。": "Note: Language preference syncs across all your logged-in devices and affects the language of API error messages.",
|
||||
"自定义 OAuth 提供商": "Custom OAuth Providers",
|
||||
"配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商": "Configure custom OAuth providers, supports GitHub Enterprise, GitLab, Gitea, Nextcloud, Keycloak, ORY and other OAuth 2.0 compatible identity providers",
|
||||
"回调 URL 格式": "Callback URL format",
|
||||
"添加提供商": "Add Provider",
|
||||
"编辑提供商": "Edit Provider",
|
||||
"选择预设...": "Select preset...",
|
||||
"输入基础 URL": "Enter base URL",
|
||||
"例如": "e.g.",
|
||||
"提供商名称": "Provider Name",
|
||||
"标识符 (Slug)": "Slug",
|
||||
"授权端点": "Authorization Endpoint",
|
||||
"令牌端点": "Token Endpoint",
|
||||
"用户信息端点": "User Info Endpoint",
|
||||
"用户 ID 字段": "User ID Field",
|
||||
"支持 JSONPath,如 sub, id, data.user.id": "Supports JSONPath, e.g. sub, id, data.user.id",
|
||||
"用户名字段": "Username Field",
|
||||
"支持 JSONPath,如 preferred_username, login, data.user.username": "Supports JSONPath, e.g. preferred_username, login, data.user.username",
|
||||
"显示名称字段": "Display Name Field",
|
||||
"支持 JSONPath,如 name, display_name, data.user.name": "Supports JSONPath, e.g. name, display_name, data.user.name",
|
||||
"邮箱字段": "Email Field",
|
||||
"支持 JSONPath,如 email, data.user.email": "Supports JSONPath, e.g. email, data.user.email",
|
||||
"授权范围 (Scopes)": "Scopes",
|
||||
"认证方式": "Auth Style",
|
||||
"自动检测": "Auto-detect",
|
||||
"参数传递": "In Parameters",
|
||||
"Basic Auth 头": "Basic Auth Header",
|
||||
"暂无自定义 OAuth 提供商": "No custom OAuth providers",
|
||||
"确定要删除该提供商吗?": "Are you sure you want to delete this provider?",
|
||||
"创建成功": "Created successfully",
|
||||
"更新成功": "Updated successfully",
|
||||
"确认解绑": "Confirm Unbind",
|
||||
"确定要解绑 {{name}} 吗?": "Are you sure you want to unbind {{name}}?",
|
||||
"解绑成功": "Unbind successful",
|
||||
"{{name}} ID": "{{name}} ID",
|
||||
"使用 {{name}} 继续": "Continue with {{name}}",
|
||||
"端点 URL 必须以 http:// 或 https:// 开头:": "Endpoint URL must start with http:// or https://: ",
|
||||
"OAuth 配置错误:授权端点必须是完整的 URL(以 http:// 或 https:// 开头)": "OAuth configuration error: Authorization endpoint must be a full URL (starting with http:// or https://)",
|
||||
"OAuth 登录失败:": "OAuth login failed: ",
|
||||
"必填:请输入服务器地址以自动生成完整端点 URL": "Required: Enter server address to auto-generate full endpoint URLs",
|
||||
"填写服务器地址后自动生成:": "Auto-generated after entering server address: ",
|
||||
"自动生成:": "Auto-generated: ",
|
||||
"请先填写服务器地址,以自动生成完整的端点 URL": "Please enter the server address first to auto-generate full endpoint URLs",
|
||||
"端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "Endpoint URL must be a full address (starting with http:// or https://)"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2740,6 +2740,49 @@
|
||||
"语言偏好": "语言偏好",
|
||||
"选择您的首选界面语言,设置将自动保存并同步到所有设备": "选择您的首选界面语言,设置将自动保存并同步到所有设备",
|
||||
"语言偏好已保存": "语言偏好已保存",
|
||||
"提示:语言偏好会同步到您登录的所有设备,并影响API返回的错误消息语言。": "提示:语言偏好会同步到您登录的所有设备,并影响API返回的错误消息语言。"
|
||||
"提示:语言偏好会同步到您登录的所有设备,并影响API返回的错误消息语言。": "提示:语言偏好会同步到您登录的所有设备,并影响API返回的错误消息语言。",
|
||||
"自定义 OAuth 提供商": "自定义 OAuth 提供商",
|
||||
"配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商": "配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商",
|
||||
"回调 URL 格式": "回调 URL 格式",
|
||||
"添加提供商": "添加提供商",
|
||||
"编辑提供商": "编辑提供商",
|
||||
"选择预设...": "选择预设...",
|
||||
"输入基础 URL": "输入基础 URL",
|
||||
"例如": "例如",
|
||||
"提供商名称": "提供商名称",
|
||||
"标识符 (Slug)": "标识符 (Slug)",
|
||||
"授权端点": "授权端点",
|
||||
"令牌端点": "令牌端点",
|
||||
"用户信息端点": "用户信息端点",
|
||||
"用户 ID 字段": "用户 ID 字段",
|
||||
"支持 JSONPath,如 sub, id, data.user.id": "支持 JSONPath,如 sub, id, data.user.id",
|
||||
"用户名字段": "用户名字段",
|
||||
"支持 JSONPath,如 preferred_username, login, data.user.username": "支持 JSONPath,如 preferred_username, login, data.user.username",
|
||||
"显示名称字段": "显示名称字段",
|
||||
"支持 JSONPath,如 name, display_name, data.user.name": "支持 JSONPath,如 name, display_name, data.user.name",
|
||||
"邮箱字段": "邮箱字段",
|
||||
"支持 JSONPath,如 email, data.user.email": "支持 JSONPath,如 email, data.user.email",
|
||||
"授权范围 (Scopes)": "授权范围 (Scopes)",
|
||||
"认证方式": "认证方式",
|
||||
"自动检测": "自动检测",
|
||||
"参数传递": "参数传递",
|
||||
"Basic Auth 头": "Basic Auth 头",
|
||||
"暂无自定义 OAuth 提供商": "暂无自定义 OAuth 提供商",
|
||||
"确定要删除该提供商吗?": "确定要删除该提供商吗?",
|
||||
"创建成功": "创建成功",
|
||||
"更新成功": "更新成功",
|
||||
"确认解绑": "确认解绑",
|
||||
"确定要解绑 {{name}} 吗?": "确定要解绑 {{name}} 吗?",
|
||||
"解绑成功": "解绑成功",
|
||||
"{{name}} ID": "{{name}} ID",
|
||||
"使用 {{name}} 继续": "使用 {{name}} 继续",
|
||||
"端点 URL 必须以 http:// 或 https:// 开头:": "端点 URL 必须以 http:// 或 https:// 开头:",
|
||||
"OAuth 配置错误:授权端点必须是完整的 URL(以 http:// 或 https:// 开头)": "OAuth 配置错误:授权端点必须是完整的 URL(以 http:// 或 https:// 开头)",
|
||||
"OAuth 登录失败:": "OAuth 登录失败:",
|
||||
"必填:请输入服务器地址以自动生成完整端点 URL": "必填:请输入服务器地址以自动生成完整端点 URL",
|
||||
"填写服务器地址后自动生成:": "填写服务器地址后自动生成:",
|
||||
"自动生成:": "自动生成:",
|
||||
"请先填写服务器地址,以自动生成完整的端点 URL": "请先填写服务器地址,以自动生成完整的端点 URL",
|
||||
"端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "端点 URL 必须是完整地址(以 http:// 或 https:// 开头)"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ export default function GeneralSettings(props) {
|
||||
DefaultCollapseSidebar: false,
|
||||
DemoSiteEnabled: false,
|
||||
SelfUseModeEnabled: false,
|
||||
'token_setting.max_user_tokens': 1000,
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
@@ -287,6 +288,19 @@ export default function GeneralSettings(props) {
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row gutter={16}>
|
||||
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
|
||||
<Form.InputNumber
|
||||
label={t('用户最大令牌数量')}
|
||||
field={'token_setting.max_user_tokens'}
|
||||
step={1}
|
||||
min={1}
|
||||
extraText={t('每个用户最多可创建的令牌数量,默认 1000,设置过大可能会影响性能')}
|
||||
placeholder={'1000'}
|
||||
onChange={handleFieldChange('token_setting.max_user_tokens')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row>
|
||||
<Button size='default' onClick={onSubmit}>
|
||||
{t('保存通用设置')}
|
||||
|
||||
@@ -1,488 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useContext } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
Card,
|
||||
Button,
|
||||
Switch,
|
||||
Typography,
|
||||
Row,
|
||||
Col,
|
||||
Avatar,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { API, showSuccess, showError } from '../../../helpers';
|
||||
import { StatusContext } from '../../../context/Status';
|
||||
import { UserContext } from '../../../context/User';
|
||||
import { useUserPermissions } from '../../../hooks/common/useUserPermissions';
|
||||
import { mergeAdminConfig, useSidebar } from '../../../hooks/common/useSidebar';
|
||||
import { Settings } from 'lucide-react';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
export default function SettingsSidebarModulesUser() {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [statusState] = useContext(StatusContext);
|
||||
|
||||
// 使用后端权限验证替代前端角色判断
|
||||
const {
|
||||
permissions,
|
||||
loading: permissionsLoading,
|
||||
hasSidebarSettingsPermission,
|
||||
isSidebarSectionAllowed,
|
||||
isSidebarModuleAllowed,
|
||||
} = useUserPermissions();
|
||||
|
||||
// 使用useSidebar钩子获取刷新方法
|
||||
const { refreshUserConfig } = useSidebar();
|
||||
|
||||
// 如果没有边栏设置权限,不显示此组件
|
||||
if (!permissionsLoading && !hasSidebarSettingsPermission()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// 权限加载中,显示加载状态
|
||||
if (permissionsLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// 根据用户权限生成默认配置
|
||||
const generateDefaultConfig = () => {
|
||||
const defaultConfig = {};
|
||||
|
||||
// 聊天区域 - 所有用户都可以访问
|
||||
if (isSidebarSectionAllowed('chat')) {
|
||||
defaultConfig.chat = {
|
||||
enabled: true,
|
||||
playground: isSidebarModuleAllowed('chat', 'playground'),
|
||||
chat: isSidebarModuleAllowed('chat', 'chat'),
|
||||
};
|
||||
}
|
||||
|
||||
// 控制台区域 - 所有用户都可以访问
|
||||
if (isSidebarSectionAllowed('console')) {
|
||||
defaultConfig.console = {
|
||||
enabled: true,
|
||||
detail: isSidebarModuleAllowed('console', 'detail'),
|
||||
token: isSidebarModuleAllowed('console', 'token'),
|
||||
log: isSidebarModuleAllowed('console', 'log'),
|
||||
midjourney: isSidebarModuleAllowed('console', 'midjourney'),
|
||||
task: isSidebarModuleAllowed('console', 'task'),
|
||||
};
|
||||
}
|
||||
|
||||
// 个人中心区域 - 所有用户都可以访问
|
||||
if (isSidebarSectionAllowed('personal')) {
|
||||
defaultConfig.personal = {
|
||||
enabled: true,
|
||||
topup: isSidebarModuleAllowed('personal', 'topup'),
|
||||
personal: isSidebarModuleAllowed('personal', 'personal'),
|
||||
};
|
||||
}
|
||||
|
||||
// 管理员区域 - 只有管理员可以访问
|
||||
if (isSidebarSectionAllowed('admin')) {
|
||||
defaultConfig.admin = {
|
||||
enabled: true,
|
||||
channel: isSidebarModuleAllowed('admin', 'channel'),
|
||||
models: isSidebarModuleAllowed('admin', 'models'),
|
||||
deployment: isSidebarModuleAllowed('admin', 'deployment'),
|
||||
redemption: isSidebarModuleAllowed('admin', 'redemption'),
|
||||
user: isSidebarModuleAllowed('admin', 'user'),
|
||||
subscription: isSidebarModuleAllowed('admin', 'subscription'),
|
||||
setting: isSidebarModuleAllowed('admin', 'setting'),
|
||||
};
|
||||
}
|
||||
|
||||
return defaultConfig;
|
||||
};
|
||||
|
||||
// 用户个人左侧边栏模块设置
|
||||
const [sidebarModulesUser, setSidebarModulesUser] = useState({});
|
||||
|
||||
// 管理员全局配置
|
||||
const [adminConfig, setAdminConfig] = useState(null);
|
||||
|
||||
// 处理区域级别开关变更
|
||||
function handleSectionChange(sectionKey) {
|
||||
return (checked) => {
|
||||
const newModules = {
|
||||
...sidebarModulesUser,
|
||||
[sectionKey]: {
|
||||
...sidebarModulesUser[sectionKey],
|
||||
enabled: checked,
|
||||
},
|
||||
};
|
||||
setSidebarModulesUser(newModules);
|
||||
console.log('用户边栏区域配置变更:', sectionKey, checked, newModules);
|
||||
};
|
||||
}
|
||||
|
||||
// 处理功能级别开关变更
|
||||
function handleModuleChange(sectionKey, moduleKey) {
|
||||
return (checked) => {
|
||||
const newModules = {
|
||||
...sidebarModulesUser,
|
||||
[sectionKey]: {
|
||||
...sidebarModulesUser[sectionKey],
|
||||
[moduleKey]: checked,
|
||||
},
|
||||
};
|
||||
setSidebarModulesUser(newModules);
|
||||
console.log(
|
||||
'用户边栏功能配置变更:',
|
||||
sectionKey,
|
||||
moduleKey,
|
||||
checked,
|
||||
newModules,
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
// 重置为默认配置(基于权限过滤)
|
||||
function resetSidebarModules() {
|
||||
const defaultConfig = generateDefaultConfig();
|
||||
setSidebarModulesUser(defaultConfig);
|
||||
showSuccess(t('已重置为默认配置'));
|
||||
console.log('用户边栏配置重置为默认:', defaultConfig);
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
async function onSubmit() {
|
||||
setLoading(true);
|
||||
try {
|
||||
console.log('保存用户边栏配置:', sidebarModulesUser);
|
||||
const res = await API.put('/api/user/self', {
|
||||
sidebar_modules: JSON.stringify(sidebarModulesUser),
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('保存成功'));
|
||||
console.log('用户边栏配置保存成功');
|
||||
|
||||
// 刷新useSidebar钩子中的用户配置,实现实时更新
|
||||
await refreshUserConfig();
|
||||
console.log('用户边栏配置已刷新,边栏将立即更新');
|
||||
} else {
|
||||
showError(message);
|
||||
console.error('用户边栏配置保存失败:', message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('保存失败,请重试'));
|
||||
console.error('用户边栏配置保存异常:', error);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
// 统一的配置加载逻辑
|
||||
useEffect(() => {
|
||||
const loadConfigs = async () => {
|
||||
try {
|
||||
// 获取管理员全局配置
|
||||
if (statusState?.status?.SidebarModulesAdmin) {
|
||||
try {
|
||||
const adminConf = JSON.parse(
|
||||
statusState.status.SidebarModulesAdmin,
|
||||
);
|
||||
const mergedAdminConf = mergeAdminConfig(adminConf);
|
||||
setAdminConfig(mergedAdminConf);
|
||||
console.log('加载管理员边栏配置:', mergedAdminConf);
|
||||
} catch (error) {
|
||||
const mergedAdminConf = mergeAdminConfig(null);
|
||||
setAdminConfig(mergedAdminConf);
|
||||
console.log(
|
||||
'加载管理员边栏配置失败,使用默认配置:',
|
||||
mergedAdminConf,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
const mergedAdminConf = mergeAdminConfig(null);
|
||||
setAdminConfig(mergedAdminConf);
|
||||
console.log('管理员边栏配置缺失,使用默认配置:', mergedAdminConf);
|
||||
}
|
||||
|
||||
// 获取用户个人配置
|
||||
const userRes = await API.get('/api/user/self');
|
||||
if (userRes.data.success && userRes.data.data.sidebar_modules) {
|
||||
let userConf;
|
||||
// 检查sidebar_modules是字符串还是对象
|
||||
if (typeof userRes.data.data.sidebar_modules === 'string') {
|
||||
userConf = JSON.parse(userRes.data.data.sidebar_modules);
|
||||
} else {
|
||||
userConf = userRes.data.data.sidebar_modules;
|
||||
}
|
||||
console.log('从API加载的用户配置:', userConf);
|
||||
|
||||
// 确保用户配置也经过权限过滤
|
||||
const filteredUserConf = {};
|
||||
Object.keys(userConf).forEach((sectionKey) => {
|
||||
if (isSidebarSectionAllowed(sectionKey)) {
|
||||
filteredUserConf[sectionKey] = { ...userConf[sectionKey] };
|
||||
// 过滤不允许的模块
|
||||
Object.keys(userConf[sectionKey]).forEach((moduleKey) => {
|
||||
if (
|
||||
moduleKey !== 'enabled' &&
|
||||
!isSidebarModuleAllowed(sectionKey, moduleKey)
|
||||
) {
|
||||
delete filteredUserConf[sectionKey][moduleKey];
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
setSidebarModulesUser(filteredUserConf);
|
||||
console.log('权限过滤后的用户配置:', filteredUserConf);
|
||||
} else {
|
||||
// 如果用户没有配置,使用权限过滤后的默认配置
|
||||
const defaultConfig = generateDefaultConfig();
|
||||
setSidebarModulesUser(defaultConfig);
|
||||
console.log('用户无配置,使用默认配置:', defaultConfig);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载边栏配置失败:', error);
|
||||
// 出错时也使用默认配置
|
||||
const defaultConfig = generateDefaultConfig();
|
||||
setSidebarModulesUser(defaultConfig);
|
||||
}
|
||||
};
|
||||
|
||||
// 只有权限加载完成且有边栏设置权限时才加载配置
|
||||
if (!permissionsLoading && hasSidebarSettingsPermission()) {
|
||||
loadConfigs();
|
||||
}
|
||||
}, [
|
||||
statusState,
|
||||
permissionsLoading,
|
||||
hasSidebarSettingsPermission,
|
||||
isSidebarSectionAllowed,
|
||||
isSidebarModuleAllowed,
|
||||
]);
|
||||
|
||||
// 检查功能是否被管理员允许
|
||||
const isAllowedByAdmin = (sectionKey, moduleKey = null) => {
|
||||
if (!adminConfig) return true;
|
||||
|
||||
if (moduleKey) {
|
||||
return (
|
||||
adminConfig[sectionKey]?.enabled && adminConfig[sectionKey]?.[moduleKey]
|
||||
);
|
||||
} else {
|
||||
return adminConfig[sectionKey]?.enabled;
|
||||
}
|
||||
};
|
||||
|
||||
// 区域配置数据(根据后端权限过滤)
|
||||
const sectionConfigs = [
|
||||
{
|
||||
key: 'chat',
|
||||
title: t('聊天区域'),
|
||||
description: t('操练场和聊天功能'),
|
||||
modules: [
|
||||
{
|
||||
key: 'playground',
|
||||
title: t('操练场'),
|
||||
description: t('AI模型测试环境'),
|
||||
},
|
||||
{ key: 'chat', title: t('聊天'), description: t('聊天会话管理') },
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'console',
|
||||
title: t('控制台区域'),
|
||||
description: t('数据管理和日志查看'),
|
||||
modules: [
|
||||
{ key: 'detail', title: t('数据看板'), description: t('系统数据统计') },
|
||||
{ key: 'token', title: t('令牌管理'), description: t('API令牌管理') },
|
||||
{ key: 'log', title: t('使用日志'), description: t('API使用记录') },
|
||||
{
|
||||
key: 'midjourney',
|
||||
title: t('绘图日志'),
|
||||
description: t('绘图任务记录'),
|
||||
},
|
||||
{ key: 'task', title: t('任务日志'), description: t('系统任务记录') },
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'personal',
|
||||
title: t('个人中心区域'),
|
||||
description: t('用户个人功能'),
|
||||
modules: [
|
||||
{ key: 'topup', title: t('钱包管理'), description: t('余额充值管理') },
|
||||
{
|
||||
key: 'personal',
|
||||
title: t('个人设置'),
|
||||
description: t('个人信息设置'),
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
key: 'admin',
|
||||
title: t('管理员区域'),
|
||||
description: t('系统管理功能'),
|
||||
modules: [
|
||||
{ key: 'channel', title: t('渠道管理'), description: t('API渠道配置') },
|
||||
{ key: 'models', title: t('模型管理'), description: t('AI模型配置') },
|
||||
{
|
||||
key: 'deployment',
|
||||
title: t('模型部署'),
|
||||
description: t('模型部署管理'),
|
||||
},
|
||||
{
|
||||
key: 'subscription',
|
||||
title: t('订阅管理'),
|
||||
description: t('订阅套餐管理'),
|
||||
},
|
||||
{
|
||||
key: 'redemption',
|
||||
title: t('兑换码管理'),
|
||||
description: t('兑换码生成管理'),
|
||||
},
|
||||
{ key: 'user', title: t('用户管理'), description: t('用户账户管理') },
|
||||
{
|
||||
key: 'setting',
|
||||
title: t('系统设置'),
|
||||
description: t('系统参数配置'),
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
.filter((section) => {
|
||||
// 使用后端权限验证替代前端角色判断
|
||||
return isSidebarSectionAllowed(section.key);
|
||||
})
|
||||
.map((section) => ({
|
||||
...section,
|
||||
modules: section.modules.filter((module) =>
|
||||
isSidebarModuleAllowed(section.key, module.key),
|
||||
),
|
||||
}))
|
||||
.filter(
|
||||
(section) =>
|
||||
// 过滤掉没有可用模块的区域
|
||||
section.modules.length > 0 && isAllowedByAdmin(section.key),
|
||||
);
|
||||
|
||||
return (
|
||||
<Card className='!rounded-2xl shadow-sm border-0'>
|
||||
{/* 卡片头部 */}
|
||||
<div className='flex items-center mb-4'>
|
||||
<Avatar size='small' color='purple' className='mr-3 shadow-md'>
|
||||
<Settings size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Typography.Text className='text-lg font-medium'>
|
||||
{t('左侧边栏个人设置')}
|
||||
</Typography.Text>
|
||||
<div className='text-xs text-gray-600'>
|
||||
{t('个性化设置左侧边栏的显示内容')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className='mb-4'>
|
||||
<Text type='secondary' className='text-sm text-gray-600'>
|
||||
{t('您可以个性化设置侧边栏的要显示功能')}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{sectionConfigs.map((section) => (
|
||||
<div key={section.key} className='mb-6'>
|
||||
{/* 区域标题和总开关 */}
|
||||
<div className='flex justify-between items-center mb-4 p-4 bg-gray-50 rounded-xl border border-gray-200'>
|
||||
<div>
|
||||
<div className='font-semibold text-base text-gray-900 mb-1'>
|
||||
{section.title}
|
||||
</div>
|
||||
<Text className='text-xs text-gray-600'>
|
||||
{section.description}
|
||||
</Text>
|
||||
</div>
|
||||
<Switch
|
||||
checked={sidebarModulesUser[section.key]?.enabled !== false}
|
||||
onChange={handleSectionChange(section.key)}
|
||||
size='default'
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 功能模块网格 */}
|
||||
<Row gutter={[12, 12]}>
|
||||
{section.modules.map((module) => (
|
||||
<Col key={module.key} xs={24} sm={12} md={8} lg={6} xl={6}>
|
||||
<Card
|
||||
className={`!rounded-xl border border-gray-200 hover:border-blue-300 transition-all duration-200 ${
|
||||
sidebarModulesUser[section.key]?.enabled !== false
|
||||
? ''
|
||||
: 'opacity-50'
|
||||
}`}
|
||||
bodyStyle={{ padding: '16px' }}
|
||||
hoverable
|
||||
>
|
||||
<div className='flex justify-between items-center h-full'>
|
||||
<div className='flex-1 text-left'>
|
||||
<div className='font-semibold text-sm text-gray-900 mb-1'>
|
||||
{module.title}
|
||||
</div>
|
||||
<Text className='text-xs text-gray-600 leading-relaxed block'>
|
||||
{module.description}
|
||||
</Text>
|
||||
</div>
|
||||
<div className='ml-4'>
|
||||
<Switch
|
||||
checked={
|
||||
sidebarModulesUser[section.key]?.[module.key] !==
|
||||
false
|
||||
}
|
||||
onChange={handleModuleChange(section.key, module.key)}
|
||||
size='default'
|
||||
disabled={
|
||||
sidebarModulesUser[section.key]?.enabled === false
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</Col>
|
||||
))}
|
||||
</Row>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{/* 底部按钮 */}
|
||||
<div className='flex justify-end gap-3 mt-6 pt-4 border-t border-gray-200'>
|
||||
<Button
|
||||
type='tertiary'
|
||||
onClick={resetSidebarModules}
|
||||
className='!rounded-lg'
|
||||
>
|
||||
{t('重置为默认')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={onSubmit}
|
||||
loading={loading}
|
||||
className='!rounded-lg'
|
||||
>
|
||||
{t('保存设置')}
|
||||
</Button>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user