Compare commits

...

25 Commits

Author SHA1 Message Date
CaIon
04dd761880 fix: update LIKE pattern sanitization for token search
- Change ESCAPE character from '\' to '!' for compatibility with MySQL/PostgreSQL/SQLite
- Adjust sanitization logic to escape '!' and '_' correctly, improving input validation for search queries
2026-02-06 19:52:35 +08:00
CaIon
5ff9bc3851 chore: add fmt import for improved logging in token controller 2026-02-06 18:01:11 +08:00
Calcium-Ion
053699fa98 Merge commit from fork
fix: harden token search with pagination, rate limiting and input validation
2026-02-06 17:54:40 +08:00
CaIon
3e1be18310 fix: harden token search with pagination, rate limiting and input validation
- Add configurable per-user token creation limit (max_user_tokens)
- Sanitize search input patterns to prevent expensive queries
- Add per-user search rate limiting (by user ID)
- Add pagination to search endpoint with strict page size cap
- Skip empty search fields instead of matching nothing
- Hide internal errors from API responses
- Fix Interface2String float64 formatting causing config parse failures
- Add float-string fallback in config system for int/uint fields
2026-02-06 17:52:19 +08:00
Calcium-Ion
f3d6e99b28 Merge pull request #2863 from prnake/feat/claude-opus-4-6
feat: add claude-opus-4-6
2026-02-06 16:18:00 +08:00
Calcium-Ion
6de8dea9b9 Merge commit from fork
🔒 fix(security): sanitize AI-generated HTML to prevent XSS in playground
2026-02-06 16:16:20 +08:00
t0ng7u
ab5456eb10 🔒 fix(security): sanitize AI-generated HTML to prevent XSS in playground
Mitigate XSS vulnerabilities in the playground where AI-generated content
is rendered without sanitization, allowing potential script injection via
prompt injection attacks.

MarkdownRenderer.jsx:
- Replace dangerouslySetInnerHTML with a sandboxed iframe for HTML preview
- Use sandbox="allow-same-origin" to block script execution while allowing
  CSS rendering and iframe height auto-sizing
- Add SandboxedHtmlPreview component with automatic height adjustment

CodeViewer.jsx:
- Add escapeHtml() utility to encode HTML entities before rendering
- Rewrite highlightJson() to process tokens iteratively, escaping each
  token and structural text before wrapping in syntax highlighting spans
- Escape non-JSON and very-large content paths that previously bypassed
  sanitization
- Update linkRegex to correctly match URLs containing & entities

These changes only affect the playground (AI output rendering). Admin-
configured content (home page, about page, footer, notices) remains
unaffected as they use separate code paths and are within the trusted
admin boundary.
2026-02-06 15:10:05 +08:00
Papersnake
8e6071f146 Merge branch 'feat/claude-opus-4-6' of https://github.com/prnake/new-api into feat/claude-opus-4-6 2026-02-06 11:59:14 +08:00
Papersnake
729610beb0 fix: set temperature to 1 2026-02-06 11:56:38 +08:00
Papersnake
c9f5de7048 feat: support adaptive thinking 2026-02-06 11:01:23 +08:00
Papersnake
ff71786d8d fix: aws claude 2026-02-06 09:51:15 +08:00
Papersnake
2504818b5a feat: add claude-opus-4-6 2026-02-06 09:12:45 +08:00
CaIon
9a7a29eed8 Remove deprecated components and hooks 2026-02-05 23:04:49 +08:00
CaIon
4d797e0a5b Update .gitattributes to enhance text file handling and mark additional file types for LF normalization and binary detection 2026-02-05 22:57:32 +08:00
CaIon
3766e3248f Add .gitattributes to mark frontend as vendored 2026-02-05 22:53:07 +08:00
CaIon
b55e42eda7 feat(api): add 'cookie' to passthroughSkipHeaderNamesLower 2026-02-05 22:16:35 +08:00
CaIon
af54ea85d2 feat(oauth): implement custom OAuth provider management #1106
- Add support for custom OAuth providers, including creation, retrieval, updating, and deletion.
- Introduce new model and controller for managing custom OAuth providers.
- Enhance existing OAuth logic to accommodate custom providers.
- Update API routes for custom OAuth provider management.
- Include i18n support for custom OAuth-related messages.
2026-02-05 21:18:43 +08:00
CaIon
632baadb57 feat(oauth): migrate GitHub user identification from login to numeric ID 2026-02-05 20:30:48 +08:00
CaIon
df6c669e73 refactor: unify OAuth providers with i18n support
- Introduce Provider interface pattern for standard OAuth protocols
- Create unified controller/oauth.go with common OAuth logic
- Add OAuthError type for translatable error messages
- Add i18n keys and translations (zh/en) for OAuth messages
- Use common.ApiErrorI18n/ApiSuccessI18n for consistent responses
- Preserve backward compatibility for existing routes and data
2026-02-05 20:21:38 +08:00
Calcium-Ion
c540033985 Merge pull request #2853 from QuantumNous/remove/claude-legacy-models
remove: drop support for claude-2 and claude-1 series models
2026-02-05 17:26:29 +08:00
CaIon
1d611d89d2 remove: drop support for claude-2 and claude-1 series models
- Remove claude-instant-1.2, claude-2, claude-2.0, claude-2.1 from model lists
- Remove /v1/complete endpoint support (legacy completion API)
- Remove RequestModeCompletion and related code paths
- Simplify handler functions by removing requestMode parameter
- Update all channel adaptors that referenced claude handlers
2026-02-05 17:20:46 +08:00
Calcium-Ion
7b1451caa7 Merge pull request #2848 from seefs001/fix/gemini-empty-responses-local-usage
fix: charge local input tokens when Gemini returns empty response
2026-02-05 16:24:23 +08:00
Seefs
ecebd619a4 fix: charge local input tokens when Gemini returns empty response 2026-02-05 15:57:17 +08:00
Seefs
9d73aa44b7 Merge pull request #2826 from dahetaoa/fix-codex-and-sqlite
fix: optimize Codex relay
2026-02-05 13:43:09 +08:00
dahetaoa
05ed9d43af fix(relay/codex): optimize headers and ensure instructions presence 2026-02-04 21:43:33 +00:00
77 changed files with 3930 additions and 3141 deletions

38
.gitattributes vendored Normal file
View File

@@ -0,0 +1,38 @@
# Auto detect text files and perform LF normalization
* text=auto
# Go files
*.go text eol=lf
# Config files
*.json text eol=lf
*.yaml text eol=lf
*.yml text eol=lf
*.toml text eol=lf
*.md text eol=lf
# JavaScript/TypeScript files
*.js text eol=lf
*.jsx text eol=lf
*.ts text eol=lf
*.tsx text eol=lf
*.html text eol=lf
*.css text eol=lf
# Shell scripts
*.sh text eol=lf
# Binary files
*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.ico binary
*.woff binary
*.woff2 binary
# ============================================
# GitHub Linguist - Language Detection
# ============================================
# Mark web frontend as vendored so GitHub recognizes this as a Go project
electron/** linguist-vendored

View File

@@ -175,6 +175,10 @@ var (
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
// Per-user search rate limit (applies after authentication, keyed by user ID)
SearchRateLimitNum = 10
SearchRateLimitDuration int64 = 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute

View File

@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
case bool:
if inter.(bool) {
return "true"

386
controller/custom_oauth.go Normal file
View File

@@ -0,0 +1,386 @@
package controller
import (
"net/http"
"strconv"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-gonic/gin"
)
// CustomOAuthProviderResponse is the response structure for custom OAuth providers
// It excludes sensitive fields like client_secret
type CustomOAuthProviderResponse struct {
Id int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
}
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
return &CustomOAuthProviderResponse{
Id: p.Id,
Name: p.Name,
Slug: p.Slug,
Enabled: p.Enabled,
ClientId: p.ClientId,
AuthorizationEndpoint: p.AuthorizationEndpoint,
TokenEndpoint: p.TokenEndpoint,
UserInfoEndpoint: p.UserInfoEndpoint,
Scopes: p.Scopes,
UserIdField: p.UserIdField,
UsernameField: p.UsernameField,
DisplayNameField: p.DisplayNameField,
EmailField: p.EmailField,
WellKnown: p.WellKnown,
AuthStyle: p.AuthStyle,
}
}
// GetCustomOAuthProviders returns all custom OAuth providers
func GetCustomOAuthProviders(c *gin.Context) {
providers, err := model.GetAllCustomOAuthProviders()
if err != nil {
common.ApiError(c, err)
return
}
response := make([]*CustomOAuthProviderResponse, len(providers))
for i, p := range providers {
response[i] = toCustomOAuthProviderResponse(p)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": response,
})
}
// GetCustomOAuthProvider returns a single custom OAuth provider by ID
func GetCustomOAuthProvider(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiErrorMsg(c, "无效的 ID")
return
}
provider, err := model.GetCustomOAuthProviderById(id)
if err != nil {
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": toCustomOAuthProviderResponse(provider),
})
}
// CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider
type CreateCustomOAuthProviderRequest struct {
Name string `json:"name" binding:"required"`
Slug string `json:"slug" binding:"required"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id" binding:"required"`
ClientSecret string `json:"client_secret" binding:"required"`
AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
TokenEndpoint string `json:"token_endpoint" binding:"required"`
UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
}
// CreateCustomOAuthProvider creates a new custom OAuth provider
func CreateCustomOAuthProvider(c *gin.Context) {
var req CreateCustomOAuthProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
return
}
// Check if slug is already taken
if model.IsSlugTaken(req.Slug, 0) {
common.ApiErrorMsg(c, "该 Slug 已被使用")
return
}
// Check if slug conflicts with built-in providers
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
return
}
provider := &model.CustomOAuthProvider{
Name: req.Name,
Slug: req.Slug,
Enabled: req.Enabled,
ClientId: req.ClientId,
ClientSecret: req.ClientSecret,
AuthorizationEndpoint: req.AuthorizationEndpoint,
TokenEndpoint: req.TokenEndpoint,
UserInfoEndpoint: req.UserInfoEndpoint,
Scopes: req.Scopes,
UserIdField: req.UserIdField,
UsernameField: req.UsernameField,
DisplayNameField: req.DisplayNameField,
EmailField: req.EmailField,
WellKnown: req.WellKnown,
AuthStyle: req.AuthStyle,
}
if err := model.CreateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
return
}
// Register the provider in the OAuth registry
oauth.RegisterOrUpdateCustomProvider(provider)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "创建成功",
"data": toCustomOAuthProviderResponse(provider),
})
}
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
type UpdateCustomOAuthProviderRequest struct {
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
func UpdateCustomOAuthProvider(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiErrorMsg(c, "无效的 ID")
return
}
var req UpdateCustomOAuthProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
return
}
// Get existing provider
provider, err := model.GetCustomOAuthProviderById(id)
if err != nil {
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
return
}
oldSlug := provider.Slug
// Check if new slug is taken by another provider
if req.Slug != "" && req.Slug != provider.Slug {
if model.IsSlugTaken(req.Slug, id) {
common.ApiErrorMsg(c, "该 Slug 已被使用")
return
}
// Check if slug conflicts with built-in providers
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
return
}
}
// Update fields
if req.Name != "" {
provider.Name = req.Name
}
if req.Slug != "" {
provider.Slug = req.Slug
}
provider.Enabled = req.Enabled
if req.ClientId != "" {
provider.ClientId = req.ClientId
}
if req.ClientSecret != "" {
provider.ClientSecret = req.ClientSecret
}
if req.AuthorizationEndpoint != "" {
provider.AuthorizationEndpoint = req.AuthorizationEndpoint
}
if req.TokenEndpoint != "" {
provider.TokenEndpoint = req.TokenEndpoint
}
if req.UserInfoEndpoint != "" {
provider.UserInfoEndpoint = req.UserInfoEndpoint
}
if req.Scopes != "" {
provider.Scopes = req.Scopes
}
if req.UserIdField != "" {
provider.UserIdField = req.UserIdField
}
if req.UsernameField != "" {
provider.UsernameField = req.UsernameField
}
if req.DisplayNameField != "" {
provider.DisplayNameField = req.DisplayNameField
}
if req.EmailField != "" {
provider.EmailField = req.EmailField
}
provider.WellKnown = req.WellKnown
provider.AuthStyle = req.AuthStyle
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
return
}
// Update the provider in the OAuth registry
if oldSlug != provider.Slug {
oauth.UnregisterCustomProvider(oldSlug)
}
oauth.RegisterOrUpdateCustomProvider(provider)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "更新成功",
"data": toCustomOAuthProviderResponse(provider),
})
}
// DeleteCustomOAuthProvider deletes a custom OAuth provider
func DeleteCustomOAuthProvider(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiErrorMsg(c, "无效的 ID")
return
}
// Get existing provider to get slug
provider, err := model.GetCustomOAuthProviderById(id)
if err != nil {
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
return
}
// Check if there are any user bindings
count, _ := model.GetBindingCountByProviderId(id)
if count > 0 {
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
return
}
if err := model.DeleteCustomOAuthProvider(id); err != nil {
common.ApiError(c, err)
return
}
// Unregister the provider from the OAuth registry
oauth.UnregisterCustomProvider(provider.Slug)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "删除成功",
})
}
// GetUserOAuthBindings returns all OAuth bindings for the current user
func GetUserOAuthBindings(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
common.ApiErrorMsg(c, "未登录")
return
}
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
// Build response with provider info
type BindingResponse struct {
ProviderId int `json:"provider_id"`
ProviderName string `json:"provider_name"`
ProviderSlug string `json:"provider_slug"`
ProviderUserId string `json:"provider_user_id"`
}
response := make([]BindingResponse, 0)
for _, binding := range bindings {
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
if err != nil {
continue // Skip if provider not found
}
response = append(response, BindingResponse{
ProviderId: binding.ProviderId,
ProviderName: provider.Name,
ProviderSlug: provider.Slug,
ProviderUserId: binding.ProviderUserId,
})
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": response,
})
}
// UnbindCustomOAuth unbinds a custom OAuth provider from the current user
func UnbindCustomOAuth(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
common.ApiErrorMsg(c, "未登录")
return
}
providerIdStr := c.Param("provider_id")
providerId, err := strconv.Atoi(providerIdStr)
if err != nil {
common.ApiErrorMsg(c, "无效的提供商 ID")
return
}
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "解绑成功",
})
}

View File

@@ -1,223 +0,0 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type DiscordResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type DiscordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res.Body.Close()
var discordResponse DiscordResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
return nil, err
}
if discordResponse.AccessToken == "" {
common.SysError("Discord 获取 Token 失败,请检查设置!")
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("Discord 获取用户信息失败!请检查设置!")
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
}
var discordUser DiscordUser
err = json.NewDecoder(res2.Body).Decode(&discordUser)
if err != nil {
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
common.SysError("Discord 获取用户信息为空!请检查设置!")
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
}
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
if discordUser.ID != "" {
user.Username = discordUser.ID
} else {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if discordUser.Name != "" {
user.DisplayName = discordUser.Name
} else {
user.DisplayName = "Discord User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.DiscordId = discordUser.UID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}

View File

@@ -1,240 +0,0 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type GitHubOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type GitHubUser struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse GitHubOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
var githubUser GitHubUser
err = json.NewDecoder(res2.Body).Decode(&githubUser)
if err != nil {
return nil, err
}
if githubUser.Login == "" {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
return &githubUser, nil
}
func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
GitHubBind(c)
return
}
if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
})
return
}
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
GitHubId: githubUser.Login,
}
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
} else {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
})
return
}
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
GitHubId: githubUser.Login,
}
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 GitHub 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.GitHubId = githubUser.Login
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
affCode := c.Query("aff")
if affCode != "" {
session.Set("aff", affCode)
}
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}

View File

@@ -1,268 +0,0 @@
package controller
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type LinuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Linux DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
if code == "" {
return nil, errors.New("invalid code")
}
// Get access token using Basic auth
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to connect to Linux DO server")
}
defer res.Body.Close()
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
return nil, err
}
if tokenRes.AccessToken == "" {
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
}
// Get user info
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
req, err = http.NewRequest("GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
req.Header.Set("Accept", "application/json")
res2, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to get user info from Linux DO")
}
defer res2.Body.Close()
var linuxdoUser LinuxdoUser
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
return nil, err
}
if linuxdoUser.Id == 0 {
return nil, errors.New("invalid user info returned")
}
return &linuxdoUser, nil
}
func LinuxdoOAuth(c *gin.Context) {
session := sessions.Default(c)
errorCode := c.Query("error")
if errorCode != "" {
errorDescription := c.Query("error_description")
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": errorDescription,
})
return
}
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
// Check if user exists
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
err := user.FillUserByLinuxDOId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = linuxdoUser.Name
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/console_setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -129,6 +130,30 @@ func GetStatus(c *gin.Context) {
data["faq"] = console_setting.GetFAQ()
}
// Add enabled custom OAuth providers
customProviders := oauth.GetEnabledCustomProviders()
if len(customProviders) > 0 {
type CustomOAuthInfo struct {
Name string `json:"name"`
Slug string `json:"slug"`
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
Scopes string `json:"scopes"`
}
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
for _, p := range customProviders {
config := p.GetConfig()
providersInfo = append(providersInfo, CustomOAuthInfo{
Name: config.Name,
Slug: config.Slug,
ClientId: config.ClientId,
AuthorizationEndpoint: config.AuthorizationEndpoint,
Scopes: config.Scopes,
})
}
data["custom_oauth_providers"] = providersInfo
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

312
controller/oauth.go Normal file
View File

@@ -0,0 +1,312 @@
package controller
import (
"fmt"
"net/http"
"strconv"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// providerParams returns map with Provider key for i18n templates
func providerParams(name string) map[string]any {
return map[string]any{"Provider": name}
}
// GenerateOAuthCode generates a state code for OAuth CSRF protection
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
affCode := c.Query("aff")
if affCode != "" {
session.Set("aff", affCode)
}
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}
// HandleOAuth handles OAuth callback for all standard OAuth providers
func HandleOAuth(c *gin.Context) {
providerName := c.Param("provider")
provider := oauth.GetProvider(providerName)
if provider == nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
})
return
}
session := sessions.Default(c)
// 1. Validate state (CSRF protection)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": i18n.T(c, i18n.MsgOAuthStateInvalid),
})
return
}
// 2. Check if user is already logged in (bind flow)
username := session.Get("username")
if username != nil {
handleOAuthBind(c, provider)
return
}
// 3. Check if provider is enabled
if !provider.IsEnabled() {
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
return
}
// 4. Handle error from provider
errorCode := c.Query("error")
if errorCode != "" {
errorDescription := c.Query("error_description")
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": errorDescription,
})
return
}
// 5. Exchange code for token
code := c.Query("code")
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
if err != nil {
handleOAuthError(c, err)
return
}
// 6. Get user info
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
if err != nil {
handleOAuthError(c, err)
return
}
// 7. Find or create user
user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
if err != nil {
switch err.(type) {
case *OAuthUserDeletedError:
common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
case *OAuthRegistrationDisabledError:
common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
default:
common.ApiError(c, err)
}
return
}
// 8. Check user status
if user.Status != common.UserStatusEnabled {
common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
return
}
// 9. Setup login
setupLogin(user, c)
}
// handleOAuthBind handles binding OAuth account to existing user
func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
if !provider.IsEnabled() {
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
return
}
// Exchange code for token
code := c.Query("code")
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
if err != nil {
handleOAuthError(c, err)
return
}
// Get user info
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
if err != nil {
handleOAuthError(c, err)
return
}
// Check if this OAuth account is already bound (check both new ID and legacy ID)
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
return
}
// Also check legacy ID to prevent duplicate bindings during migration period
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
if provider.IsUserIDTaken(legacyID) {
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
return
}
}
// Get current user from session
session := sessions.Default(c)
id := session.Get("id")
user := model.User{Id: id.(int)}
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
// Handle binding based on provider type
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
// Custom provider: use user_oauth_bindings table
err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
if err != nil {
common.ApiError(c, err)
return
}
} else {
// Built-in provider: update user record directly
provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
}
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
}
// findOrCreateOAuthUser finds existing user or creates new user
func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
user := &model.User{}
// Check if user already exists with new ID
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
if err != nil {
return nil, err
}
// Check if user has been deleted
if user.Id == 0 {
return nil, &OAuthUserDeletedError{}
}
return user, nil
}
// Try to find user with legacy ID (for GitHub migration from login to numeric ID)
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
if provider.IsUserIDTaken(legacyID) {
err := provider.FillUserByProviderID(user, legacyID)
if err != nil {
return nil, err
}
if user.Id != 0 {
// Found user with legacy ID, migrate to new ID
common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
user.Id, legacyID, oauthUser.ProviderUserID))
if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
// Continue with login even if migration fails
}
return user, nil
}
}
}
// User doesn't exist, create new user if registration is enabled
if !common.RegisterEnabled {
return nil, &OAuthRegistrationDisabledError{}
}
// Set up new user
user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
if oauthUser.DisplayName != "" {
user.DisplayName = oauthUser.DisplayName
} else if oauthUser.Username != "" {
user.DisplayName = oauthUser.Username
} else {
user.DisplayName = provider.GetName() + " User"
}
if oauthUser.Email != "" {
user.Email = oauthUser.Email
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
// Handle affiliate code
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
return nil, err
}
// For custom providers, create the binding after user is created
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBinding(binding); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
// Don't fail the registration, just log the error
}
} else {
// Built-in provider: set the provider user ID on the user model
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := user.Update(false); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
}
}
return user, nil
}
// Error types for OAuth
type OAuthUserDeletedError struct{}
func (e *OAuthUserDeletedError) Error() string {
return "user has been deleted"
}
type OAuthRegistrationDisabledError struct{}
func (e *OAuthRegistrationDisabledError) Error() string {
return "registration is disabled"
}
// handleOAuthError handles OAuth errors and returns translated message
func handleOAuthError(c *gin.Context, err error) {
switch e := err.(type) {
case *oauth.OAuthError:
if e.Params != nil {
common.ApiErrorI18n(c, e.MsgKey, e.Params)
} else {
common.ApiErrorI18n(c, e.MsgKey)
}
case *oauth.TrustLevelError:
common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
default:
common.ApiError(c, err)
}
}

View File

@@ -1,228 +0,0 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
if oidcResponse.AccessToken == "" {
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -1,6 +1,7 @@
package controller
import (
"fmt"
"net/http"
"strconv"
"strings"
@@ -8,6 +9,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
@@ -31,16 +33,17 @@ func SearchTokens(c *gin.Context) {
userId := c.GetInt("id")
keyword := c.Query("keyword")
token := c.Query("token")
tokens, err := model.SearchUserTokens(userId, keyword, token)
pageInfo := common.GetPageQuery(c)
tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": tokens,
})
pageInfo.SetTotal(int(total))
pageInfo.SetItems(tokens)
common.ApiSuccess(c, pageInfo)
return
}
@@ -157,6 +160,20 @@ func AddToken(c *gin.Context) {
return
}
}
// 检查用户令牌数量是否已达上限
maxTokens := operation_setting.GetMaxUserTokens()
count, err := model.CountUserTokens(c.GetInt("id"))
if err != nil {
common.ApiError(c, err)
return
}
if int(count) >= maxTokens {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
})
return
}
key, err := common.GenerateKey()
if err != nil {
common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)

View File

@@ -264,9 +264,20 @@ const (
// OAuth related messages
const (
MsgOAuthInvalidCode = "oauth.invalid_code"
MsgOAuthGetUserErr = "oauth.get_user_error"
MsgOAuthAccountUsed = "oauth.account_used"
MsgOAuthInvalidCode = "oauth.invalid_code"
MsgOAuthGetUserErr = "oauth.get_user_error"
MsgOAuthAccountUsed = "oauth.account_used"
MsgOAuthUnknownProvider = "oauth.unknown_provider"
MsgOAuthStateInvalid = "oauth.state_invalid"
MsgOAuthNotEnabled = "oauth.not_enabled"
MsgOAuthUserDeleted = "oauth.user_deleted"
MsgOAuthUserBanned = "oauth.user_banned"
MsgOAuthBindSuccess = "oauth.bind_success"
MsgOAuthAlreadyBound = "oauth.already_bound"
MsgOAuthConnectFailed = "oauth.connect_failed"
MsgOAuthTokenFailed = "oauth.token_failed"
MsgOAuthUserInfoEmpty = "oauth.user_info_empty"
MsgOAuthTrustLevelLow = "oauth.trust_level_low"
)
// Model layer error messages (for translation in controller)
@@ -276,3 +287,14 @@ const (
MsgUuidDuplicate = "common.uuid_duplicate"
MsgInvalidInput = "common.invalid_input"
)
// Custom OAuth provider related messages
const (
MsgCustomOAuthNotFound = "custom_oauth.not_found"
MsgCustomOAuthSlugEmpty = "custom_oauth.slug_empty"
MsgCustomOAuthSlugExists = "custom_oauth.slug_exists"
MsgCustomOAuthNameEmpty = "custom_oauth.name_empty"
MsgCustomOAuthHasBindings = "custom_oauth.has_bindings"
MsgCustomOAuthBindingNotFound = "custom_oauth.binding_not_found"
MsgCustomOAuthProviderIdInvalid = "custom_oauth.provider_id_field_invalid"
)

View File

@@ -223,9 +223,29 @@ ability.repair_running: "A repair task is already running, please try again late
oauth.invalid_code: "Invalid authorization code"
oauth.get_user_error: "Failed to get user information"
oauth.account_used: "This account has been bound to another user"
oauth.unknown_provider: "Unknown OAuth provider"
oauth.state_invalid: "State parameter is empty or mismatched"
oauth.not_enabled: "{{.Provider}} login and registration has not been enabled by administrator"
oauth.user_deleted: "User has been deleted"
oauth.user_banned: "User has been banned"
oauth.bind_success: "Binding successful"
oauth.already_bound: "This {{.Provider}} account has already been bound"
oauth.connect_failed: "Unable to connect to {{.Provider}} server, please try again later"
oauth.token_failed: "Failed to get token from {{.Provider}}, please check settings"
oauth.user_info_empty: "{{.Provider}} returned empty user info, please check settings"
oauth.trust_level_low: "Linux DO trust level does not meet the minimum required by administrator"
# Model layer error messages
redeem.failed: "Redemption failed, please try again later"
user.create_default_token_error: "Failed to create default token"
common.uuid_duplicate: "Please retry, the system generated a duplicate UUID!"
common.invalid_input: "Invalid input"
# Custom OAuth provider messages
custom_oauth.not_found: "Custom OAuth provider not found"
custom_oauth.slug_empty: "Slug cannot be empty"
custom_oauth.slug_exists: "Slug already exists"
custom_oauth.name_empty: "Provider name cannot be empty"
custom_oauth.has_bindings: "Cannot delete provider with existing user bindings"
custom_oauth.binding_not_found: "OAuth binding not found"
custom_oauth.provider_id_field_invalid: "Could not extract user ID from provider response"

View File

@@ -224,9 +224,29 @@ ability.repair_running: "已经有一个修复任务在运行中,请稍后再
oauth.invalid_code: "无效的授权码"
oauth.get_user_error: "获取用户信息失败"
oauth.account_used: "该账户已被其他用户绑定"
oauth.unknown_provider: "未知的 OAuth 提供商"
oauth.state_invalid: "state 参数为空或不匹配"
oauth.not_enabled: "管理员未开启通过 {{.Provider}} 登录以及注册"
oauth.user_deleted: "用户已注销"
oauth.user_banned: "用户已被封禁"
oauth.bind_success: "绑定成功"
oauth.already_bound: "该 {{.Provider}} 账户已被绑定"
oauth.connect_failed: "无法连接至 {{.Provider}} 服务器,请稍后重试"
oauth.token_failed: "{{.Provider}} 获取 Token 失败,请检查设置"
oauth.user_info_empty: "{{.Provider}} 获取用户信息为空,请检查设置"
oauth.trust_level_low: "Linux DO 信任等级未达到管理员设置的最低信任等级"
# Model layer error messages
redeem.failed: "兑换失败,请稍后重试"
user.create_default_token_error: "创建默认令牌失败"
common.uuid_duplicate: "请重试,系统生成的 UUID 竟然重复了!"
common.invalid_input: "输入不合法"
# Custom OAuth provider messages
custom_oauth.not_found: "自定义 OAuth 提供商不存在"
custom_oauth.slug_empty: "标识符不能为空"
custom_oauth.slug_exists: "标识符已存在"
custom_oauth.name_empty: "提供商名称不能为空"
custom_oauth.has_bindings: "无法删除已有用户绑定的提供商"
custom_oauth.binding_not_found: "OAuth 绑定不存在"
custom_oauth.provider_id_field_invalid: "无法从提供商响应中提取用户 ID"

View File

@@ -18,6 +18,7 @@ import (
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/QuantumNous/new-api/router"
"github.com/QuantumNous/new-api/service"
_ "github.com/QuantumNous/new-api/setting/performance_setting"
@@ -291,5 +292,12 @@ func InitResources() error {
// Register user language loader for lazy loading
i18n.SetUserLangLoader(model.GetUserLanguage)
// Load custom OAuth providers from database
err = oauth.LoadCustomProviders()
if err != nil {
common.SysError("failed to load custom OAuth providers: " + err.Error())
// Don't return error, custom OAuth is not critical
}
return nil
}

View File

@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
}
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
// instead of client IP, making it resistant to proxy rotation attacks.
// Must be used AFTER authentication middleware (UserAuth).
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
if common.RedisEnabled {
return func(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
userRedisRateLimiter(c, maxRequestNum, duration, key)
}
}
// It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
key := fmt.Sprintf("%s:user:%d", mark, userId)
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
}
}
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
// (to support user-ID-based keys).
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
ctx := context.Background()
rdb := common.RDB
listLength, err := rdb.LLen(ctx, key).Result()
if err != nil {
fmt.Println(err.Error())
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
fmt.Println(err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
fmt.Println(err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
}
}
// SearchRateLimit returns a per-user rate limiter for search endpoints.
// 10 requests per 60 seconds per user (by user ID, not IP).
func SearchRateLimit() func(c *gin.Context) {
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
}

View File

@@ -0,0 +1,158 @@
package model
import (
"errors"
"strings"
"time"
)
// CustomOAuthProvider stores configuration for custom OAuth providers
type CustomOAuthProvider struct {
Id int `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
// Field mapping configuration (supports JSONPath via gjson)
UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
// Advanced options
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (CustomOAuthProvider) TableName() string {
return "custom_oauth_providers"
}
// GetAllCustomOAuthProviders returns all custom OAuth providers
func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
var providers []*CustomOAuthProvider
err := DB.Order("id asc").Find(&providers).Error
return providers, err
}
// GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
var providers []*CustomOAuthProvider
err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
return providers, err
}
// GetCustomOAuthProviderById returns a custom OAuth provider by ID
func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
var provider CustomOAuthProvider
err := DB.First(&provider, id).Error
if err != nil {
return nil, err
}
return &provider, nil
}
// GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
var provider CustomOAuthProvider
err := DB.Where("slug = ?", slug).First(&provider).Error
if err != nil {
return nil, err
}
return &provider, nil
}
// CreateCustomOAuthProvider creates a new custom OAuth provider
func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if err := validateCustomOAuthProvider(provider); err != nil {
return err
}
return DB.Create(provider).Error
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if err := validateCustomOAuthProvider(provider); err != nil {
return err
}
return DB.Save(provider).Error
}
// DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
func DeleteCustomOAuthProvider(id int) error {
// First, delete all user bindings for this provider
if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
return err
}
return DB.Delete(&CustomOAuthProvider{}, id).Error
}
// IsSlugTaken checks if a slug is already taken by another provider
func IsSlugTaken(slug string, excludeId int) bool {
var count int64
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
if excludeId > 0 {
query = query.Where("id != ?", excludeId)
}
query.Count(&count)
return count > 0
}
// validateCustomOAuthProvider validates a custom OAuth provider configuration
func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if provider.Name == "" {
return errors.New("provider name is required")
}
if provider.Slug == "" {
return errors.New("provider slug is required")
}
// Slug must be lowercase and contain only alphanumeric characters and hyphens
slug := strings.ToLower(provider.Slug)
for _, c := range slug {
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
}
}
provider.Slug = slug
if provider.ClientId == "" {
return errors.New("client ID is required")
}
if provider.AuthorizationEndpoint == "" {
return errors.New("authorization endpoint is required")
}
if provider.TokenEndpoint == "" {
return errors.New("token endpoint is required")
}
if provider.UserInfoEndpoint == "" {
return errors.New("user info endpoint is required")
}
// Set defaults for field mappings if empty
if provider.UserIdField == "" {
provider.UserIdField = "sub"
}
if provider.UsernameField == "" {
provider.UsernameField = "preferred_username"
}
if provider.DisplayNameField == "" {
provider.DisplayNameField = "name"
}
if provider.EmailField == "" {
provider.EmailField = "email"
}
if provider.Scopes == "" {
provider.Scopes = "openid profile email"
}
return nil
}

View File

@@ -274,6 +274,8 @@ func migrateDB() error {
&SubscriptionOrder{},
&UserSubscription{},
&SubscriptionPreConsumeRecord{},
&CustomOAuthProvider{},
&UserOAuthBinding{},
)
if err != nil {
return err
@@ -320,6 +322,8 @@ func migrateDBFast() error {
{&SubscriptionOrder{}, "SubscriptionOrder"},
{&UserSubscription{}, "UserSubscription"},
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
{&UserOAuthBinding{}, "UserOAuthBinding"},
}
// 动态计算migration数量确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -63,12 +64,104 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
return tokens, err
}
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
// 规则:
// 1. 转义 ! 和 _使用 ! 作为 ESCAPE 字符,兼容 MySQL/PostgreSQL/SQLite
// 2. 连续的 % 合并为单个 %
// 3. 最多允许 2 个 %
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
// 5. 不含 % 时按精确匹配
func sanitizeLikePattern(input string) (string, error) {
// 1. 先转义 ESCAPE 字符 ! 自身,再转义 _
// 使用 ! 而非 \ 作为 ESCAPE 字符,避免 MySQL 中反斜杠的字符串转义问题
input = strings.ReplaceAll(input, "!", "!!")
input = strings.ReplaceAll(input, `_`, `!_`)
// 2. 连续的 % 直接拒绝
if strings.Contains(input, "%%") {
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
}
// 3. 统计 % 数量,不得超过 2
count := strings.Count(input, "%")
if count > 2 {
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
}
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
if count > 0 {
stripped := strings.ReplaceAll(input, "%", "")
if len(stripped) < 2 {
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
}
return input, nil
}
// 5. 无 % 时,精确全匹配
return input, nil
}
const searchHardLimit = 100
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
// model 层强制截断
if limit <= 0 || limit > searchHardLimit {
limit = searchHardLimit
}
if offset < 0 {
offset = 0
}
if token != "" {
token = strings.Trim(token, "sk-")
}
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
maxTokens := operation_setting.GetMaxUserTokens()
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
if hasFuzzy {
count, err := CountUserTokens(userId)
if err != nil {
common.SysLog("failed to count user tokens: " + err.Error())
return nil, 0, errors.New("获取令牌数量失败")
}
if int(count) > maxTokens {
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
}
}
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
if keyword != "" {
keywordPattern, err := sanitizeLikePattern(keyword)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '!'", keywordPattern)
}
if token != "" {
tokenPattern, err := sanitizeLikePattern(token)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '!'", tokenPattern)
}
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT
err = baseQuery.Limit(maxTokens).Count(&total).Error
if err != nil {
common.SysError("failed to count search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
// 再分页查数据
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
if err != nil {
common.SysError("failed to search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
return tokens, total, nil
}
func ValidateUserToken(key string) (token *Token, err error) {

View File

@@ -540,6 +540,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
// UpdateGitHubId updates the user's GitHub ID (used for migration from login to numeric ID)
func (user *User) UpdateGitHubId(newGitHubId string) error {
if user.Id == 0 {
return errors.New("user id is empty")
}
return DB.Model(user).Update("github_id", newGitHubId).Error
}
func (user *User) FillUserByDiscordId() error {
if user.DiscordId == "" {
return errors.New("discord id 为空!")

125
model/user_oauth_binding.go Normal file
View File

@@ -0,0 +1,125 @@
package model
import (
"errors"
"time"
)
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
type UserOAuthBinding struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"index;not null"` // User ID
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
CreatedAt time.Time `json:"created_at"`
// Composite unique index to prevent duplicate bindings
// One OAuth account can only be bound to one user
}
func (UserOAuthBinding) TableName() string {
return "user_oauth_bindings"
}
// GetUserOAuthBindingsByUserId returns all OAuth bindings for a user
func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
var bindings []*UserOAuthBinding
err := DB.Where("user_id = ?", userId).Find(&bindings).Error
return bindings, err
}
// GetUserOAuthBinding returns a specific binding for a user and provider
func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
var binding UserOAuthBinding
err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
if err != nil {
return nil, err
}
return &binding, nil
}
// GetUserByOAuthBinding finds a user by provider ID and provider user ID
func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
var binding UserOAuthBinding
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error
if err != nil {
return nil, err
}
var user User
err = DB.First(&user, binding.UserId).Error
if err != nil {
return nil, err
}
return &user, nil
}
// IsProviderUserIdTaken checks if a provider user ID is already bound to any user
func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
var count int64
DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count)
return count > 0
}
// CreateUserOAuthBinding creates a new OAuth binding
func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
if binding.UserId == 0 {
return errors.New("user ID is required")
}
if binding.ProviderId == 0 {
return errors.New("provider ID is required")
}
if binding.ProviderUserId == "" {
return errors.New("provider user ID is required")
}
// Check if this provider user ID is already taken
if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) {
return errors.New("this OAuth account is already bound to another user")
}
binding.CreatedAt = time.Now()
return DB.Create(binding).Error
}
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
// Check if the new provider user ID is already taken by another user
var existingBinding UserOAuthBinding
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error
if err == nil && existingBinding.UserId != userId {
return errors.New("this OAuth account is already bound to another user")
}
// Check if user already has a binding for this provider
var binding UserOAuthBinding
err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
if err != nil {
// No existing binding, create new one
return CreateUserOAuthBinding(&UserOAuthBinding{
UserId: userId,
ProviderId: providerId,
ProviderUserId: newProviderUserId,
})
}
// Update existing binding
return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
}
// DeleteUserOAuthBinding deletes an OAuth binding
func DeleteUserOAuthBinding(userId, providerId int) error {
return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error
}
// DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user
func DeleteUserOAuthBindingsByUserId(userId int) error {
return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error
}
// GetBindingCountByProviderId returns the number of bindings for a provider
func GetBindingCountByProviderId(providerId int) (int64, error) {
var count int64
err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error
return count, err
}

172
oauth/discord.go Normal file
View File

@@ -0,0 +1,172 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func init() {
Register("discord", &DiscordProvider{})
}
// DiscordProvider implements OAuth for Discord
type DiscordProvider struct{}
type discordOAuthResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type discordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func (p *DiscordProvider) GetName() string {
return "Discord"
}
func (p *DiscordProvider) IsEnabled() bool {
return system_setting.GetDiscordSettings().Enabled
}
func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)])
settings := system_setting.GetDiscordSettings()
redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)
values := url.Values{}
values.Set("client_id", settings.ClientId)
values.Set("client_secret", settings.ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", redirectUri)
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri)
req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode)
var discordResponse discordOAuthResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if discordResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"})
}
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope)
return &OAuthToken{
AccessToken: discordResponse.AccessToken,
TokenType: discordResponse.TokenType,
RefreshToken: discordResponse.RefreshToken,
ExpiresIn: discordResponse.ExpiresIn,
Scope: discordResponse.Scope,
IDToken: discordResponse.IDToken,
}, nil
}
func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info")
req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
var discordUser discordUser
err = json.NewDecoder(res.Body).Decode(&discordUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"})
}
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name)
return &OAuthUser{
ProviderUserID: discordUser.UID,
Username: discordUser.ID,
DisplayName: discordUser.Name,
}, nil
}
func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsDiscordIdAlreadyTaken(providerUserID)
}
func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.DiscordId = providerUserID
return user.FillUserByDiscordId()
}
func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.DiscordId = providerUserID
}
func (p *DiscordProvider) GetProviderPrefix() string {
return "discord_"
}

268
oauth/generic.go Normal file
View File

@@ -0,0 +1,268 @@
package oauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// AuthStyle defines how to send client credentials
const (
AuthStyleAutoDetect = 0 // Auto-detect based on server response
AuthStyleInParams = 1 // Send client_id and client_secret as POST parameters
AuthStyleInHeader = 2 // Send as Basic Auth header
)
// GenericOAuthProvider implements OAuth for custom/generic OAuth providers
type GenericOAuthProvider struct {
config *model.CustomOAuthProvider
}
// NewGenericOAuthProvider creates a new generic OAuth provider from config
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
return &GenericOAuthProvider{config: config}
}
func (p *GenericOAuthProvider) GetName() string {
return p.config.Name
}
func (p *GenericOAuthProvider) IsEnabled() bool {
return p.config.Enabled
}
func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
return p.config
}
func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
values := url.Values{}
values.Set("grant_type", "authorization_code")
values.Set("code", code)
values.Set("redirect_uri", redirectUri)
// Determine auth style
authStyle := p.config.AuthStyle
if authStyle == AuthStyleAutoDetect {
// Default to params style for most OAuth servers
authStyle = AuthStyleInParams
}
var req *http.Request
var err error
if authStyle == AuthStyleInParams {
values.Set("client_id", p.config.ClientId)
values.Set("client_secret", p.config.ClientSecret)
}
req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
if authStyle == AuthStyleInHeader {
// Basic Auth
credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
req.Header.Set("Authorization", "Basic "+credentials)
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
body, err := io.ReadAll(res.Body)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
return nil, err
}
bodyStr := string(body)
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
// Try to parse as JSON first
var tokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
IDToken string `json:"id_token"`
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
}
if err := json.Unmarshal(body, &tokenResponse); err != nil {
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
parsedValues, parseErr := url.ParseQuery(bodyStr)
if parseErr != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
return nil, err
}
tokenResponse.AccessToken = parsedValues.Get("access_token")
tokenResponse.TokenType = parsedValues.Get("token_type")
tokenResponse.Scope = parsedValues.Get("scope")
}
if tokenResponse.Error != "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
}
if tokenResponse.AccessToken == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
return &OAuthToken{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
ExpiresIn: tokenResponse.ExpiresIn,
Scope: tokenResponse.Scope,
IDToken: tokenResponse.IDToken,
}, nil
}
func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
// Set authorization header
tokenType := token.TokenType
if tokenType == "" {
tokenType = "Bearer"
}
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
body, err := io.ReadAll(res.Body)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
return nil, err
}
bodyStr := string(body)
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
// Extract fields using gjson (supports JSONPath-like syntax)
userId := gjson.Get(bodyStr, p.config.UserIdField).String()
username := gjson.Get(bodyStr, p.config.UsernameField).String()
displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
email := gjson.Get(bodyStr, p.config.EmailField).String()
// If user ID field returns a number, convert it
if userId == "" {
// Try to get as number
userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
if userIdNum.Exists() {
userId = userIdNum.Raw
// Remove quotes if present
userId = strings.Trim(userId, "\"")
}
}
if userId == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
p.config.Slug, userId, username, displayName, email)
return &OAuthUser{
ProviderUserID: userId,
Username: username,
DisplayName: displayName,
Email: email,
}, nil
}
func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
}
func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
if err != nil {
return err
}
*user = *foundUser
return nil
}
func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
// For generic providers, we store the binding in user_oauth_bindings table
// This is handled separately in the OAuth controller
}
func (p *GenericOAuthProvider) GetProviderPrefix() string {
return p.config.Slug + "_"
}
// GetProviderId returns the provider ID for binding purposes
func (p *GenericOAuthProvider) GetProviderId() int {
return p.config.Id
}
// IsGenericProvider returns true for generic providers
func (p *GenericOAuthProvider) IsGenericProvider() bool {
return true
}

166
oauth/github.go Normal file
View File

@@ -0,0 +1,166 @@
package oauth
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
func init() {
Register("github", &GitHubProvider{})
}
// GitHubProvider implements OAuth for GitHub
type GitHubProvider struct{}
type gitHubOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type gitHubUser struct {
Id int64 `json:"id"` // GitHub numeric ID (permanent, never changes)
Login string `json:"login"` // GitHub username (can be changed by user)
Name string `json:"name"`
Email string `json:"email"`
}
func (p *GitHubProvider) GetName() string {
return "GitHub"
}
func (p *GitHubProvider) IsEnabled() bool {
return common.GitHubOAuthEnabled
}
func (p *GitHubProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken: code=%s...", code[:min(len(code), 10)])
values := map[string]string{
"client_id": common.GitHubClientId,
"client_secret": common.GitHubClientSecret,
"code": code,
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken response status: %d", res.StatusCode)
var oAuthResponse gitHubOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if oAuthResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-GitHub] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "GitHub"})
}
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken success: scope=%s", oAuthResponse.Scope)
return &OAuthToken{
AccessToken: oAuthResponse.AccessToken,
TokenType: oAuthResponse.TokenType,
Scope: oAuthResponse.Scope,
}, nil
}
func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo: fetching user info")
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
var githubUser gitHubUser
err = json.NewDecoder(res.Body).Decode(&githubUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if githubUser.Id == 0 || githubUser.Login == "" {
logger.LogError(ctx, "[OAuth-GitHub] GetUserInfo failed: empty id or login field")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "GitHub"})
}
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo success: id=%d, login=%s, name=%s, email=%s",
githubUser.Id, githubUser.Login, githubUser.Name, githubUser.Email)
return &OAuthUser{
ProviderUserID: strconv.FormatInt(githubUser.Id, 10), // Use numeric ID as primary identifier
Username: githubUser.Login,
DisplayName: githubUser.Name,
Email: githubUser.Email,
Extra: map[string]any{
"legacy_id": githubUser.Login, // Store login for migration from old accounts
},
}, nil
}
func (p *GitHubProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsGitHubIdAlreadyTaken(providerUserID)
}
func (p *GitHubProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.GitHubId = providerUserID
return user.FillUserByGitHubId()
}
func (p *GitHubProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.GitHubId = providerUserID
}
func (p *GitHubProvider) GetProviderPrefix() string {
return "github_"
}

195
oauth/linuxdo.go Normal file
View File

@@ -0,0 +1,195 @@
package oauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
func init() {
Register("linuxdo", &LinuxDOProvider{})
}
// LinuxDOProvider implements OAuth for Linux DO
type LinuxDOProvider struct{}
type linuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func (p *LinuxDOProvider) GetName() string {
return "Linux DO"
}
func (p *LinuxDOProvider) IsEnabled() bool {
return common.LinuxDOOAuthEnabled
}
func (p *LinuxDOProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: code=%s...", code[:min(len(code), 10)])
// Get access token using Basic auth
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: token_endpoint=%s, redirect_uri=%s", tokenEndpoint, redirectURI)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken response status: %d", res.StatusCode)
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if tokenRes.AccessToken == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken failed: %s", tokenRes.Message))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Linux DO"}, tokenRes.Message)
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken success")
return &OAuthToken{
AccessToken: tokenRes.AccessToken,
}, nil
}
func (p *LinuxDOProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: user_endpoint=%s", userEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo response status: %d", res.StatusCode)
var linuxdoUser linuxdoUser
if err := json.NewDecoder(res.Body).Decode(&linuxdoUser); err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if linuxdoUser.Id == 0 {
logger.LogError(ctx, "[OAuth-LinuxDO] GetUserInfo failed: invalid user id")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Linux DO"})
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: id=%d, username=%s, name=%s, trust_level=%d, active=%v, silenced=%v",
linuxdoUser.Id, linuxdoUser.Username, linuxdoUser.Name, linuxdoUser.TrustLevel, linuxdoUser.Active, linuxdoUser.Silenced)
// Check trust level
if linuxdoUser.TrustLevel < common.LinuxDOMinimumTrustLevel {
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo: trust level too low (required=%d, current=%d)",
common.LinuxDOMinimumTrustLevel, linuxdoUser.TrustLevel))
return nil, &TrustLevelError{
Required: common.LinuxDOMinimumTrustLevel,
Current: linuxdoUser.TrustLevel,
}
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo success: id=%d, username=%s", linuxdoUser.Id, linuxdoUser.Username)
return &OAuthUser{
ProviderUserID: strconv.Itoa(linuxdoUser.Id),
Username: linuxdoUser.Username,
DisplayName: linuxdoUser.Name,
Extra: map[string]any{
"trust_level": linuxdoUser.TrustLevel,
"active": linuxdoUser.Active,
"silenced": linuxdoUser.Silenced,
},
}, nil
}
func (p *LinuxDOProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsLinuxDOIdAlreadyTaken(providerUserID)
}
func (p *LinuxDOProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.LinuxDOId = providerUserID
return user.FillUserByLinuxDOId()
}
func (p *LinuxDOProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.LinuxDOId = providerUserID
}
func (p *LinuxDOProvider) GetProviderPrefix() string {
return "linuxdo_"
}
// TrustLevelError indicates the user's trust level is too low
type TrustLevelError struct {
Required int
Current int
}
func (e *TrustLevelError) Error() string {
return "trust level too low"
}

177
oauth/oidc.go Normal file
View File

@@ -0,0 +1,177 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func init() {
Register("oidc", &OIDCProvider{})
}
// OIDCProvider implements OAuth for OIDC
type OIDCProvider struct{}
type oidcOAuthResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type oidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func (p *OIDCProvider) GetName() string {
return "OIDC"
}
func (p *OIDCProvider) IsEnabled() bool {
return system_setting.GetOIDCSettings().Enabled
}
func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)])
settings := system_setting.GetOIDCSettings()
redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)
values := url.Values{}
values.Set("client_id", settings.ClientId)
values.Set("client_secret", settings.ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", redirectUri)
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri)
req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode)
var oidcResponse oidcOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if oidcResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"})
}
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope)
return &OAuthToken{
AccessToken: oidcResponse.AccessToken,
TokenType: oidcResponse.TokenType,
RefreshToken: oidcResponse.RefreshToken,
ExpiresIn: oidcResponse.ExpiresIn,
Scope: oidcResponse.Scope,
IDToken: oidcResponse.IDToken,
}, nil
}
func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
settings := system_setting.GetOIDCSettings()
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
var oidcUser oidcUser
err = json.NewDecoder(res.Body).Decode(&oidcUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email))
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"})
}
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email)
return &OAuthUser{
ProviderUserID: oidcUser.OpenID,
Username: oidcUser.PreferredUsername,
DisplayName: oidcUser.Name,
Email: oidcUser.Email,
}, nil
}
func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsOidcIdAlreadyTaken(providerUserID)
}
func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.OidcId = providerUserID
return user.FillUserByOidcId()
}
func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.OidcId = providerUserID
}
func (p *OIDCProvider) GetProviderPrefix() string {
return "oidc_"
}

36
oauth/provider.go Normal file
View File

@@ -0,0 +1,36 @@
package oauth
import (
"context"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
// Provider defines the interface for OAuth providers
type Provider interface {
// GetName returns the display name of the provider (e.g., "GitHub", "Discord")
GetName() string
// IsEnabled returns whether this OAuth provider is enabled
IsEnabled() bool
// ExchangeToken exchanges the authorization code for an access token
// The gin.Context is passed for providers that need request info (e.g., for redirect_uri)
ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error)
// GetUserInfo retrieves user information using the access token
GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error)
// IsUserIDTaken checks if the provider user ID is already associated with an account
IsUserIDTaken(providerUserID string) bool
// FillUserByProviderID fills the user model by provider user ID
FillUserByProviderID(user *model.User, providerUserID string) error
// SetProviderUserID sets the provider user ID on the user model
SetProviderUserID(user *model.User, providerUserID string)
// GetProviderPrefix returns the prefix for auto-generated usernames (e.g., "github_")
GetProviderPrefix() string
}

134
oauth/registry.go Normal file
View File

@@ -0,0 +1,134 @@
package oauth
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
)
var (
providers = make(map[string]Provider)
mu sync.RWMutex
// customProviderSlugs tracks which providers are custom (can be unregistered)
customProviderSlugs = make(map[string]bool)
)
// Register registers an OAuth provider with the given name
func Register(name string, provider Provider) {
mu.Lock()
defer mu.Unlock()
providers[name] = provider
}
// RegisterCustom registers a custom OAuth provider (can be unregistered later)
func RegisterCustom(name string, provider Provider) {
mu.Lock()
defer mu.Unlock()
providers[name] = provider
customProviderSlugs[name] = true
}
// Unregister removes a provider from the registry
func Unregister(name string) {
mu.Lock()
defer mu.Unlock()
delete(providers, name)
delete(customProviderSlugs, name)
}
// GetProvider returns the OAuth provider for the given name
func GetProvider(name string) Provider {
mu.RLock()
defer mu.RUnlock()
return providers[name]
}
// GetAllProviders returns all registered OAuth providers
func GetAllProviders() map[string]Provider {
mu.RLock()
defer mu.RUnlock()
result := make(map[string]Provider, len(providers))
for k, v := range providers {
result[k] = v
}
return result
}
// GetEnabledCustomProviders returns all enabled custom OAuth providers
func GetEnabledCustomProviders() []*GenericOAuthProvider {
mu.RLock()
defer mu.RUnlock()
var result []*GenericOAuthProvider
for name, provider := range providers {
if customProviderSlugs[name] {
if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() {
result = append(result, gp)
}
}
}
return result
}
// IsProviderRegistered checks if a provider is registered
func IsProviderRegistered(name string) bool {
mu.RLock()
defer mu.RUnlock()
_, ok := providers[name]
return ok
}
// IsCustomProvider checks if a provider is a custom provider
func IsCustomProvider(name string) bool {
mu.RLock()
defer mu.RUnlock()
return customProviderSlugs[name]
}
// LoadCustomProviders loads all custom OAuth providers from the database
func LoadCustomProviders() error {
// First, unregister all existing custom providers
mu.Lock()
for name := range customProviderSlugs {
delete(providers, name)
}
customProviderSlugs = make(map[string]bool)
mu.Unlock()
// Load all custom providers from database
customProviders, err := model.GetAllCustomOAuthProviders()
if err != nil {
common.SysError("Failed to load custom OAuth providers: " + err.Error())
return err
}
// Register each custom provider
for _, config := range customProviders {
provider := NewGenericOAuthProvider(config)
RegisterCustom(config.Slug, provider)
common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")")
}
common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders)))
return nil
}
// ReloadCustomProviders reloads all custom OAuth providers from the database
func ReloadCustomProviders() error {
return LoadCustomProviders()
}
// RegisterOrUpdateCustomProvider registers or updates a single custom provider
func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) {
provider := NewGenericOAuthProvider(config)
mu.Lock()
defer mu.Unlock()
providers[config.Slug] = provider
customProviderSlugs[config.Slug] = true
}
// UnregisterCustomProvider unregisters a custom provider by slug
func UnregisterCustomProvider(slug string) {
Unregister(slug)
}

59
oauth/types.go Normal file
View File

@@ -0,0 +1,59 @@
package oauth
// OAuthToken represents the token received from OAuth provider
type OAuthToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Scope string `json:"scope,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
// OAuthUser represents the user info from OAuth provider
type OAuthUser struct {
// ProviderUserID is the unique identifier from the OAuth provider
ProviderUserID string
// Username is the username from the OAuth provider (e.g., GitHub login)
Username string
// DisplayName is the display name from the OAuth provider
DisplayName string
// Email is the email from the OAuth provider
Email string
// Extra contains any additional provider-specific data
Extra map[string]any
}
// OAuthError represents a translatable OAuth error
type OAuthError struct {
// MsgKey is the i18n message key
MsgKey string
// Params contains optional parameters for the message template
Params map[string]any
// RawError is the underlying error for logging purposes
RawError string
}
func (e *OAuthError) Error() string {
if e.RawError != "" {
return e.RawError
}
return e.MsgKey
}
// NewOAuthError creates a new OAuth error with the given message key
func NewOAuthError(msgKey string, params map[string]any) *OAuthError {
return &OAuthError{
MsgKey: msgKey,
Params: params,
}
}
// NewOAuthErrorWithRaw creates a new OAuth error with raw error message for logging
func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) *OAuthError {
return &OAuthError{
MsgKey: msgKey,
Params: params,
RawError: rawError,
}
}

View File

@@ -224,10 +224,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case types.RelayFormatClaude:
if supportsAliAnthropicMessages(info.UpstreamModelName) {
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeStreamHandler(c, resp, info)
}
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeHandler(c, resp, info)
}
adaptor := openai.Adaptor{}

View File

@@ -58,6 +58,8 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
"transfer-encoding": {},
"upgrade": {},
"cookie": {},
// Additional headers that should not be forwarded by name-matching passthrough rules.
"host": {},
"content-length": {},

View File

@@ -3,9 +3,6 @@ package aws
import "strings"
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
@@ -19,6 +16,7 @@ var awsModelIDMap = map[string]string{
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
@@ -82,6 +80,11 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"ap": true,
"eu": true,
},
"anthropic.claude-opus-4-6-v1": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-haiku-4-5-20251001-v1:0": {
"us": true,
"ap": true,

View File

@@ -26,6 +26,7 @@ type AwsClaudeRequest struct {
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
}
func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {

View File

@@ -233,7 +233,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
}
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body)
if handlerErr != nil {
return handlerErr, nil
}
@@ -264,7 +264,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
switch v := event.(type) {
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes))
if respErr != nil {
return respErr, nil
}
@@ -277,7 +277,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
}
}
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
claude.HandleStreamFinalResponse(c, info, claudeInfo)
return nil, claudeInfo.Usage
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
@@ -16,13 +15,7 @@ import (
"github.com/gin-gonic/gin"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -45,20 +38,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
baseURL := ""
if a.RequestMode == RequestModeMessage {
baseURL = fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
} else {
baseURL = fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl)
}
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
if info.IsClaudeBetaQuery {
baseURL = baseURL + "?beta=true"
}
@@ -90,11 +73,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(*request), nil
} else {
return RequestOpenAI2ClaudeMessage(c, *request)
}
return RequestOpenAI2ClaudeMessage(c, *request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -117,11 +96,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
return ClaudeStreamHandler(c, resp, info, a.RequestMode)
return ClaudeStreamHandler(c, resp, info)
} else {
return ClaudeHandler(c, resp, info, a.RequestMode)
return ClaudeHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {

View File

@@ -1,10 +1,6 @@
package claude
var ModelList = []string{
"claude-instant-1.2",
"claude-2",
"claude-2.0",
"claude-2.1",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
@@ -24,6 +20,11 @@ var ModelList = []string{
"claude-sonnet-4-5-20250929-thinking",
"claude-opus-4-5-20251101",
"claude-opus-4-5-20251101-thinking",
"claude-opus-4-6",
"claude-opus-4-6-max",
"claude-opus-4-6-high",
"claude-opus-4-6-medium",
"claude-opus-4-6-low",
}
var ChannelName = "claude"

View File

@@ -17,6 +17,7 @@ import (
"github.com/QuantumNous/new-api/relay/reasonmap"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -41,37 +42,6 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
}
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 4096
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]any, 0, len(textRequest.Tools))
@@ -172,7 +142,16 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
strings.HasPrefix(textRequest.Model, "claude-opus-4-6") {
claudeRequest.Model = baseModel
claudeRequest.Thinking = &dto.Thinking{
Type: "adaptive",
}
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
// 因为BudgetTokens 必须大于1024
@@ -411,7 +390,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
@@ -425,74 +404,66 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
}
}
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
}
} else {
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
if claudeResponse.ContentBlock != nil {
// 如果是文本块,尽可能发送首段文本(若存在)
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
}
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
if claudeResponse.ContentBlock != nil {
// 如果是文本块,尽可能发送首段文本(若存在)
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
}
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else {
return nil
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
}
}
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil
} else {
return nil
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
}
}
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil
} else {
return nil
}
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
@@ -503,7 +474,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
return &response
}
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -521,39 +492,26 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
if message.Thinking != nil {
thinkingContent = *message.Thinking
}
case "text":
responseText = message.GetText()
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
if message.Thinking != nil {
thinkingContent = *message.Thinking
}
case "text":
responseText = message.GetText()
}
}
choice := dto.OpenAITextResponseChoice{
@@ -586,71 +544,67 @@ type ClaudeResponseInfo struct {
Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if claudeInfo == nil {
return false
}
if claudeInfo.Usage == nil {
claudeInfo.Usage = &dto.Usage{}
}
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
}
// message_start, 获取usage
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
}
}
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
if claudeResponse.Usage != nil {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
if claudeResponse.Usage.CacheReadInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
}
if claudeResponse.Usage.CacheCreationInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
if cacheCreation5m := claudeResponse.Usage.GetCacheCreation5mTokens(); cacheCreation5m > 0 {
claudeInfo.Usage.ClaudeCacheCreation5mTokens = cacheCreation5m
}
if cacheCreation1h := claudeResponse.Usage.GetCacheCreation1hTokens(); cacheCreation1h > 0 {
claudeInfo.Usage.ClaudeCacheCreation1hTokens = cacheCreation1h
}
if claudeResponse.Usage.OutputTokens > 0 {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
}
// message_start, 获取usage
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
}
}
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
if claudeResponse.Usage != nil {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
if claudeResponse.Usage.CacheReadInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
}
if claudeResponse.Usage.CacheCreationInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
if cacheCreation5m := claudeResponse.Usage.GetCacheCreation5mTokens(); cacheCreation5m > 0 {
claudeInfo.Usage.ClaudeCacheCreation5mTokens = cacheCreation5m
}
if cacheCreation1h := claudeResponse.Usage.GetCacheCreation1hTokens(); cacheCreation1h > 0 {
claudeInfo.Usage.ClaudeCacheCreation1hTokens = cacheCreation1h
}
if claudeResponse.Usage.OutputTokens > 0 {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
}
if oaiResponse != nil {
oaiResponse.Id = claudeInfo.ResponseId
@@ -660,7 +614,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
return true
}
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
@@ -677,24 +631,19 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
maybeMarkClaudeRefusal(c, *claudeResponse.Delta.StopReason)
}
if info.RelayFormat == types.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
FormatClaudeResponseInfo(&claudeResponse, nil, claudeInfo)
if requestMode == RequestModeCompletion {
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
if claudeResponse.Message != nil {
info.UpstreamModelName = claudeResponse.Message.Model
}
} else if claudeResponse.Type == "content_block_delta" {
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
if claudeResponse.Message != nil {
info.UpstreamModelName = claudeResponse.Message.Model
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == types.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
response := StreamResponseClaude2OpenAI(&claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) {
return nil
}
@@ -706,20 +655,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysLog("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo) {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysLog("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
if info.RelayFormat == types.RelayFormatClaude {
@@ -736,7 +680,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
@@ -746,7 +690,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
var err *types.NewAPIError
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
err = HandleStreamResponseData(c, info, claudeInfo, data)
if err != nil {
return false
}
@@ -756,11 +700,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return nil, err
}
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
HandleStreamFinalResponse(c, info, claudeInfo)
return claudeInfo.Usage, nil
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.Unmarshal(data, &claudeResponse)
if err != nil {
@@ -770,26 +714,22 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {
if claudeInfo.Usage == nil {
claudeInfo.Usage = &dto.Usage{}
}
if claudeResponse.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
}
if claudeInfo.Usage == nil {
claudeInfo.Usage = &dto.Usage{}
}
if claudeResponse.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
}
var responseData []byte
switch info.RelayFormat {
case types.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
@@ -807,7 +747,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
defer service.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
@@ -824,7 +764,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody)
if handleErr != nil {
return nil, handleErr
}

View File

@@ -90,6 +90,12 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
}
}
// Codex backend requires the `instructions` field to be present.
// Keep it consistent with Codex CLI behavior by defaulting to an empty string.
if len(request.Instructions) == 0 {
request.Instructions = json.RawMessage(`""`)
}
if isCompact {
return request, nil
}
@@ -172,5 +178,15 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
req.Set("originator", "codex_cli_rs")
}
// chatgpt.com/backend-api/codex/responses is strict about Content-Type.
// Clients may omit it or include parameters like `application/json; charset=utf-8`,
// which can be rejected by the upstream. Force the exact media type.
req.Set("Content-Type", "application/json")
if info.IsStream {
req.Set("Accept", "text/event-stream")
} else if req.Get("Accept") == "" {
req.Set("Accept", "application/json")
}
return nil
}

View File

@@ -96,9 +96,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeStreamHandler(c, resp, info)
} else {
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeHandler(c, resp, info)
}
default:
adaptor := openai.Adaptor{}

View File

@@ -1258,8 +1258,7 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
if usage.CompletionTokens <= 0 {
str := responseText.String()
if len(str) > 0 {
if info.ReceivedResponseCount > 0 {
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {
usage = &dto.Usage{}

View File

@@ -103,9 +103,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeStreamHandler(c, resp, info)
} else {
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeHandler(c, resp, info)
}
default:
adaptor := openai.Adaptor{}

View File

@@ -42,6 +42,7 @@ var claudeModelMap = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
"claude-opus-4-6": "claude-opus-4-6",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -367,7 +368,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeStreamHandler(c, resp, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
@@ -380,7 +381,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
switch a.RequestMode {
case RequestModeClaude:
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeHandler(c, resp, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationHandler(c, info, resp)

View File

@@ -1,6 +1,8 @@
package vertex
import (
"encoding/json"
"github.com/QuantumNous/new-api/dto"
)
@@ -17,6 +19,7 @@ type VertexAIClaudeRequest struct {
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
@@ -33,5 +36,6 @@ func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
OutputConfig: req.OutputConfig,
}
}

View File

@@ -348,9 +348,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayFormat == types.RelayFormatClaude {
if _, ok := channelconstant.ChannelSpecialBases[info.ChannelBaseUrl]; ok {
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeStreamHandler(c, resp, info)
}
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeHandler(c, resp, info)
}
}

View File

@@ -110,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeStreamHandler(c, resp, info)
} else {
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
return claude.ClaudeHandler(c, resp, info)
}
default:
if info.RelayMode == relayconstant.RelayModeImagesGenerations {

View File

@@ -2,6 +2,7 @@ package relay
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -14,6 +15,7 @@ import (
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -49,7 +51,17 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" &&
strings.HasPrefix(request.Model, "claude-opus-4-6") {
request.Model = baseModel
request.Thinking = &dto.Thinking{
Type: "adaptive",
}
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
request.TopP = 0
request.Temperature = common.GetPointer[float64](1.0)
info.UpstreamModelName = request.Model
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(request.Model, "-thinking") {
if request.Thinking == nil {
// 因为BudgetTokens 必须大于1024

View File

@@ -113,6 +113,7 @@ type RelayInfo struct {
UserQuota int
RelayFormat types.RelayFormat
SendResponseCount int
ReceivedResponseCount int
FinalPreConsumedQuota int // 最终预消耗的配额
// BillingSource indicates whether this request is billed from wallet quota or subscription.
// "" or "wallet" => wallet; "subscription" => subscription

View File

@@ -90,10 +90,10 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
// 等待所有 goroutine 退出最多等待5秒
done := make(chan struct{})
go func() {
gopool.Go(func() {
wg.Wait()
close(done)
}()
})
select {
case <-done:
@@ -138,11 +138,11 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-pingTicker.C:
// 使用超时机制防止写操作阻塞
done := make(chan error, 1)
go func() {
gopool.Go(func() {
writeMutex.Lock()
defer writeMutex.Unlock()
done <- PingData(c)
}()
})
select {
case err := <-done:
@@ -219,14 +219,14 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\r")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
info.ReceivedResponseCount++
// 使用超时机制防止写操作阻塞
done := make(chan bool, 1)
go func() {
gopool.Go(func() {
writeMutex.Lock()
defer writeMutex.Unlock()
done <- dataHandler(data)
}()
})
select {
case success := <-done:

View File

@@ -4,6 +4,9 @@ import (
"github.com/QuantumNous/new-api/controller"
"github.com/QuantumNous/new-api/middleware"
// Import oauth package to register providers via init()
_ "github.com/QuantumNous/new-api/oauth"
"github.com/gin-contrib/gzip"
"github.com/gin-gonic/gin"
)
@@ -30,16 +33,16 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
// OAuth routes - specific routes must come before :provider wildcard
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
// Non-standard OAuth (WeChat, Telegram) - keep original routes
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
// Standard OAuth providers (GitHub, Discord, OIDC, LinuxDO) - unified route
apiRouter.GET("/oauth/:provider", middleware.CriticalRateLimit(), controller.HandleOAuth)
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
@@ -99,6 +102,10 @@ func SetApiRouter(router *gin.Engine) {
// Check-in routes
selfRoute.GET("/checkin", controller.GetCheckinStatus)
selfRoute.POST("/checkin", middleware.TurnstileCheck(), controller.DoCheckin)
// Custom OAuth bindings
selfRoute.GET("/oauth/bindings", controller.GetUserOAuthBindings)
selfRoute.DELETE("/oauth/bindings/:provider_id", controller.UnbindCustomOAuth)
}
adminRoute := userRoute.Group("/")
@@ -163,6 +170,17 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
}
// Custom OAuth provider management (admin only)
customOAuthRoute := apiRouter.Group("/custom-oauth-provider")
customOAuthRoute.Use(middleware.RootAuth())
{
customOAuthRoute.GET("/", controller.GetCustomOAuthProviders)
customOAuthRoute.GET("/:id", controller.GetCustomOAuthProvider)
customOAuthRoute.POST("/", controller.CreateCustomOAuthProvider)
customOAuthRoute.PUT("/:id", controller.UpdateCustomOAuthProvider)
customOAuthRoute.DELETE("/:id", controller.DeleteCustomOAuthProvider)
}
performanceRoute := apiRouter.Group("/performance")
performanceRoute.Use(middleware.RootAuth())
{
@@ -220,7 +238,7 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.Use(middleware.UserAuth())
{
tokenRoute.GET("/", controller.GetAllTokens)
tokenRoute.GET("/search", controller.SearchTokens)
tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
tokenRoute.GET("/:id", controller.GetToken)
tokenRoute.POST("/", controller.AddToken)
tokenRoute.PUT("/", controller.UpdateToken)

View File

@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(strValue, 10, 64)
if err != nil {
continue
// 兼容 float 格式的字符串(如 "2.000000"
floatValue, fErr := strconv.ParseFloat(strValue, 64)
if fErr != nil {
continue
}
intValue = int64(floatValue)
}
field.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(strValue, 10, 64)
if err != nil {
continue
// 兼容 float 格式的字符串
floatValue, fErr := strconv.ParseFloat(strValue, 64)
if fErr != nil || floatValue < 0 {
continue
}
uintValue = uint64(floatValue)
}
field.SetUint(uintValue)
case reflect.Float32, reflect.Float64:

View File

@@ -0,0 +1,28 @@
package operation_setting
import "github.com/QuantumNous/new-api/setting/config"
// TokenSetting 令牌相关配置
type TokenSetting struct {
MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
}
// 默认配置
var tokenSetting = TokenSetting{
MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("token_setting", &tokenSetting)
}
// GetTokenSetting 获取令牌配置
func GetTokenSetting() *TokenSetting {
return &tokenSetting
}
// GetMaxUserTokens 获取每用户最大令牌数量
func GetMaxUserTokens() int {
return GetTokenSetting().MaxUserTokens
}

View File

@@ -60,6 +60,12 @@ var defaultCacheRatio = map[string]float64{
"claude-sonnet-4-5-20250929-thinking": 0.1,
"claude-opus-4-5-20251101": 0.1,
"claude-opus-4-5-20251101-thinking": 0.1,
"claude-opus-4-6": 0.1,
"claude-opus-4-6-thinking": 0.1,
"claude-opus-4-6-max": 0.1,
"claude-opus-4-6-high": 0.1,
"claude-opus-4-6-medium": 0.1,
"claude-opus-4-6-low": 0.1,
}
var defaultCreateCacheRatio = map[string]float64{
@@ -82,6 +88,12 @@ var defaultCreateCacheRatio = map[string]float64{
"claude-sonnet-4-5-20250929-thinking": 1.25,
"claude-opus-4-5-20251101": 1.25,
"claude-opus-4-5-20251101-thinking": 1.25,
"claude-opus-4-6": 1.25,
"claude-opus-4-6-thinking": 1.25,
"claude-opus-4-6-max": 1.25,
"claude-opus-4-6-high": 1.25,
"claude-opus-4-6-medium": 1.25,
"claude-opus-4-6-low": 1.25,
}
//var defaultCreateCacheRatio = map[string]float64{}

View File

@@ -131,9 +131,6 @@ var defaultModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
"claude-haiku-4-5-20251001": 0.5, // $1 / 1M tokens
@@ -145,6 +142,11 @@ var defaultModelRatio = map[string]float64{
"claude-sonnet-4-20250514": 1.5,
"claude-sonnet-4-5-20250929": 1.5,
"claude-opus-4-5-20251101": 2.5,
"claude-opus-4-6": 2.5,
"claude-opus-4-6-max": 2.5,
"claude-opus-4-6-high": 2.5,
"claude-opus-4-6-medium": 2.5,
"claude-opus-4-6-low": 2.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"claude-opus-4-20250514": 7.5,
"claude-opus-4-1-20250805": 7.5,
@@ -589,8 +591,6 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
return 5, true
} else if strings.Contains(name, "claude-sonnet-4") || strings.Contains(name, "claude-opus-4") || strings.Contains(name, "claude-haiku-4") {
return 5, true
} else if strings.Contains(name, "claude-instant-1") || strings.Contains(name, "claude-2") {
return 3, true
}
if strings.HasPrefix(name, "gpt-3.5") {

View File

@@ -6,7 +6,7 @@ import (
"github.com/samber/lo"
)
var EffortSuffixes = []string{"-high", "-medium", "-low", "-minimal"}
var EffortSuffixes = []string{"-max", "-high", "-medium", "-low", "-minimal"}
// TrimEffortSuffix -> modelName level(low) exists
func TrimEffortSuffix(modelName string) (string, string, bool) {

View File

@@ -34,6 +34,7 @@ import {
onDiscordOAuthClicked,
onOIDCClicked,
onLinuxDOOAuthClicked,
onCustomOAuthClicked,
prepareCredentialRequestOptions,
buildAssertionResult,
isPasskeySupported,
@@ -109,6 +110,7 @@ const LoginForm = () => {
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
const githubTimeoutRef = useRef(null);
const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]);
const [customOAuthLoading, setCustomOAuthLoading] = useState({});
const logo = getLogo();
const systemName = getSystemName();
@@ -373,6 +375,23 @@ const LoginForm = () => {
}
};
// 包装的自定义OAuth登录点击处理
const handleCustomOAuthClick = (provider) => {
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
showInfo(t('请先阅读并同意用户协议和隐私政策'));
return;
}
setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: true }));
try {
onCustomOAuthClicked(provider, { shouldLogout: true });
} finally {
// 由于重定向,这里不会执行到,但为了完整性添加
setTimeout(() => {
setCustomOAuthLoading((prev) => ({ ...prev, [provider.slug]: false }));
}, 3000);
}
};
// 包装的邮箱登录选项点击处理
const handleEmailLoginClick = () => {
setEmailLoginLoading(true);
@@ -572,6 +591,23 @@ const LoginForm = () => {
</Button>
)}
{status.custom_oauth_providers &&
status.custom_oauth_providers.map((provider) => (
<Button
key={provider.slug}
theme='outline'
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
type='tertiary'
icon={<IconLock size='large' />}
onClick={() => handleCustomOAuthClick(provider)}
loading={customOAuthLoading[provider.slug]}
>
<span className='ml-3'>
{t('使用 {{name}} 继续', { name: provider.name })}
</span>
</Button>
))}
{status.telegram_oauth && (
<div className='flex justify-center my-2'>
<TelegramLoginButton

View File

@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useContext, useEffect } from 'react';
import React, { useContext, useEffect, useRef } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import {
@@ -35,6 +35,9 @@ const OAuth2Callback = (props) => {
const [searchParams] = useSearchParams();
const [, userDispatch] = useContext(UserContext);
const navigate = useNavigate();
// 防止 React 18 Strict Mode 下重复执行
const hasExecuted = useRef(false);
// 最大重试次数
const MAX_RETRIES = 3;
@@ -48,7 +51,9 @@ const OAuth2Callback = (props) => {
const { success, message, data } = resData;
if (!success) {
throw new Error(message || 'OAuth2 callback error');
// 业务错误不重试,直接显示错误
showError(message || t('授权失败'));
return;
}
if (message === 'bind') {
@@ -63,6 +68,7 @@ const OAuth2Callback = (props) => {
navigate('/console/token');
}
} catch (error) {
// 网络错误等可重试
if (retry < MAX_RETRIES) {
// 递增的退避等待
await new Promise((resolve) => setTimeout(resolve, (retry + 1) * 2000));
@@ -76,6 +82,12 @@ const OAuth2Callback = (props) => {
};
useEffect(() => {
// 防止 React 18 Strict Mode 下重复执行
if (hasExecuted.current) {
return;
}
hasExecuted.current = true;
const code = searchParams.get('code');
const state = searchParams.get('state');

View File

@@ -1,113 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState } from 'react';
import { useTranslation } from 'react-i18next';
import { Button, Modal } from '@douyinfe/semi-ui';
import { useSecureVerification } from '../../../hooks/common/useSecureVerification';
import { createApiCalls } from '../../../services/secureVerification';
import SecureVerificationModal from '../modals/SecureVerificationModal';
import ChannelKeyDisplay from '../ui/ChannelKeyDisplay';
/**
* 渠道密钥查看组件使用示例
* 展示如何使用通用安全验证系统
*/
const ChannelKeyViewExample = ({ channelId }) => {
const { t } = useTranslation();
const [keyData, setKeyData] = useState('');
const [showKeyModal, setShowKeyModal] = useState(false);
// 使用通用安全验证 Hook
const {
isModalVisible,
verificationMethods,
verificationState,
startVerification,
executeVerification,
cancelVerification,
setVerificationCode,
switchVerificationMethod,
} = useSecureVerification({
onSuccess: (result) => {
// 验证成功后处理结果
if (result.success && result.data?.key) {
setKeyData(result.data.key);
setShowKeyModal(true);
}
},
successMessage: t('密钥获取成功'),
});
// 开始查看密钥流程
const handleViewKey = async () => {
const apiCall = createApiCalls.viewChannelKey(channelId);
await startVerification(apiCall, {
title: t('查看渠道密钥'),
description: t('为了保护账户安全,请验证您的身份。'),
preferredMethod: 'passkey', // 可以指定首选验证方式
});
};
return (
<>
{/* 查看密钥按钮 */}
<Button type='primary' theme='outline' onClick={handleViewKey}>
{t('查看密钥')}
</Button>
{/* 安全验证模态框 */}
<SecureVerificationModal
visible={isModalVisible}
verificationMethods={verificationMethods}
verificationState={verificationState}
onVerify={executeVerification}
onCancel={cancelVerification}
onCodeChange={setVerificationCode}
onMethodSwitch={switchVerificationMethod}
title={verificationState.title}
description={verificationState.description}
/>
{/* 密钥显示模态框 */}
<Modal
title={t('渠道密钥信息')}
visible={showKeyModal}
onCancel={() => setShowKeyModal(false)}
footer={
<Button type='primary' onClick={() => setShowKeyModal(false)}>
{t('完成')}
</Button>
}
width={700}
style={{ maxWidth: '90vw' }}
>
<ChannelKeyDisplay
keyData={keyData}
showSuccessIcon={true}
successText={t('密钥获取成功')}
showWarning={true}
/>
</Modal>
</>
);
};
export default ChannelKeyViewExample;

View File

@@ -93,6 +93,49 @@ export function Mermaid(props) {
);
}
function SandboxedHtmlPreview({ code }) {
const iframeRef = useRef(null);
const [iframeHeight, setIframeHeight] = useState(150);
useEffect(() => {
const iframe = iframeRef.current;
if (!iframe) return;
const handleLoad = () => {
try {
const doc = iframe.contentDocument || iframe.contentWindow?.document;
if (doc) {
const height =
doc.documentElement.scrollHeight || doc.body.scrollHeight;
setIframeHeight(Math.min(Math.max(height + 16, 60), 600));
}
} catch {
// sandbox restrictions may prevent access, that's fine
}
};
iframe.addEventListener('load', handleLoad);
return () => iframe.removeEventListener('load', handleLoad);
}, [code]);
return (
<iframe
ref={iframeRef}
sandbox='allow-same-origin'
srcDoc={code}
title='HTML Preview'
style={{
width: '100%',
height: `${iframeHeight}px`,
border: 'none',
overflow: 'auto',
backgroundColor: '#fff',
borderRadius: '4px',
}}
/>
);
}
export function PreCode(props) {
const ref = useRef(null);
const [mermaidCode, setMermaidCode] = useState('');
@@ -227,7 +270,7 @@ export function PreCode(props) {
>
HTML预览:
</div>
<div dangerouslySetInnerHTML={{ __html: htmlCode }} />
<SandboxedHtmlPreview code={htmlCode} />
</div>
)}
</>

View File

@@ -1,148 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { useTranslation } from 'react-i18next';
import { Modal, Button, Input, Typography } from '@douyinfe/semi-ui';
/**
* 可复用的两步验证模态框组件
* @param {Object} props
* @param {boolean} props.visible - 是否显示模态框
* @param {string} props.code - 验证码值
* @param {boolean} props.loading - 是否正在验证
* @param {Function} props.onCodeChange - 验证码变化回调
* @param {Function} props.onVerify - 验证回调
* @param {Function} props.onCancel - 取消回调
* @param {string} props.title - 模态框标题
* @param {string} props.description - 验证描述文本
* @param {string} props.placeholder - 输入框占位文本
*/
const TwoFactorAuthModal = ({
visible,
code,
loading,
onCodeChange,
onVerify,
onCancel,
title,
description,
placeholder,
}) => {
const { t } = useTranslation();
const handleKeyDown = (e) => {
if (e.key === 'Enter' && code && !loading) {
onVerify();
}
};
return (
<Modal
title={
<div className='flex items-center'>
<div className='w-8 h-8 rounded-full bg-blue-100 dark:bg-blue-900 flex items-center justify-center mr-3'>
<svg
className='w-4 h-4 text-blue-600 dark:text-blue-400'
fill='currentColor'
viewBox='0 0 20 20'
>
<path
fillRule='evenodd'
d='M5 9V7a5 5 0 0110 0v2a2 2 0 012 2v5a2 2 0 01-2 2H5a2 2 0 01-2-2v-5a2 2 0 012-2zm8-2v2H7V7a3 3 0 016 0z'
clipRule='evenodd'
/>
</svg>
</div>
{title || t('安全验证')}
</div>
}
visible={visible}
onCancel={onCancel}
footer={
<>
<Button onClick={onCancel}>{t('取消')}</Button>
<Button
type='primary'
loading={loading}
disabled={!code || loading}
onClick={onVerify}
>
{t('验证')}
</Button>
</>
}
width={500}
style={{ maxWidth: '90vw' }}
>
<div className='space-y-6'>
{/* 安全提示 */}
<div className='bg-blue-50 dark:bg-blue-900 rounded-lg p-4'>
<div className='flex items-start'>
<svg
className='w-5 h-5 text-blue-600 dark:text-blue-400 mt-0.5 mr-3 flex-shrink-0'
fill='currentColor'
viewBox='0 0 20 20'
>
<path
fillRule='evenodd'
d='M18 10a8 8 0 11-16 0 8 8 0 0116 0zm-7-4a1 1 0 11-2 0 1 1 0 012 0zM9 9a1 1 0 000 2v3a1 1 0 001 1h1a1 1 0 100-2v-3a1 1 0 00-1-1H9z'
clipRule='evenodd'
/>
</svg>
<div>
<Typography.Text
strong
className='text-blue-800 dark:text-blue-200'
>
{t('安全验证')}
</Typography.Text>
<Typography.Text className='block text-blue-700 dark:text-blue-300 text-sm mt-1'>
{description || t('为了保护账户安全,请验证您的两步验证码。')}
</Typography.Text>
</div>
</div>
</div>
{/* 验证码输入 */}
<div>
<Typography.Text strong className='block mb-2'>
{t('验证身份')}
</Typography.Text>
<Input
placeholder={placeholder || t('请输入认证器验证码或备用码')}
value={code}
onChange={onCodeChange}
size='large'
maxLength={8}
onKeyDown={handleKeyDown}
autoFocus
/>
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
{t(
'支持6位TOTP验证码或8位备用码可到`个人设置-安全设置-两步验证设置`配置或查看。',
)}
</Typography.Text>
</div>
</div>
</Modal>
);
};
export default TwoFactorAuthModal;

View File

@@ -91,22 +91,45 @@ const codeThemeStyles = {
},
};
const highlightJson = (str) => {
return str.replace(
/("(\\u[a-zA-Z0-9]{4}|\\[^u]|[^\\"])*"(\s*:)?|\b(true|false|null)\b|-?\d+(?:\.\d*)?(?:[eE][+-]?\d+)?)/g,
(match) => {
let color = '#b5cea8';
if (/^"/.test(match)) {
color = /:$/.test(match) ? '#9cdcfe' : '#ce9178';
} else if (/true|false|null/.test(match)) {
color = '#569cd6';
}
return `<span style="color: ${color}">${match}</span>`;
},
);
const escapeHtml = (str) => {
return str
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#039;');
};
const linkRegex = /(https?:\/\/[^\s<"'\]),;}]+)/g;
const highlightJson = (str) => {
const tokenRegex =
/("(\\u[a-zA-Z0-9]{4}|\\[^u]|[^\\"])*"(\s*:)?|\b(true|false|null)\b|-?\d+(?:\.\d*)?(?:[eE][+-]?\d+)?)/g;
let result = '';
let lastIndex = 0;
let match;
while ((match = tokenRegex.exec(str)) !== null) {
// Escape non-token text (structural chars like {, }, [, ], :, comma, whitespace)
result += escapeHtml(str.slice(lastIndex, match.index));
const token = match[0];
let color = '#b5cea8';
if (/^"/.test(token)) {
color = /:$/.test(token) ? '#9cdcfe' : '#ce9178';
} else if (/true|false|null/.test(token)) {
color = '#569cd6';
}
// Escape token content before wrapping in span
result += `<span style="color: ${color}">${escapeHtml(token)}</span>`;
lastIndex = tokenRegex.lastIndex;
}
// Escape remaining text
result += escapeHtml(str.slice(lastIndex));
return result;
};
const linkRegex = /(https?:\/\/(?:[^\s<"'\]),;&}]|&amp;)+)/g;
const linkifyHtml = (html) => {
const parts = html.split(/(<[^>]+>)/g);
@@ -184,14 +207,14 @@ const CodeViewer = ({ content, title, language = 'json' }) => {
const highlightedContent = useMemo(() => {
if (contentMetrics.isVeryLarge && !isExpanded) {
return displayContent;
return escapeHtml(displayContent);
}
if (isJsonLike(displayContent, language)) {
return highlightJson(displayContent);
}
return displayContent;
return escapeHtml(displayContent);
}, [displayContent, language, contentMetrics.isVeryLarge, isExpanded]);
const renderedContent = useMemo(() => {

View File

@@ -1,40 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
export { default as SettingsPanel } from './SettingsPanel';
export { default as ChatArea } from './ChatArea';
export { default as DebugPanel } from './DebugPanel';
export { default as MessageContent } from './MessageContent';
export { default as MessageActions } from './MessageActions';
export { default as CustomInputRender } from './CustomInputRender';
export { default as SSEViewer } from './SSEViewer';
export { default as ParameterControl } from './ParameterControl';
export { default as ImageUrlInput } from './ImageUrlInput';
export { default as FloatingButtons } from './FloatingButtons';
export { default as ConfigManager } from './ConfigManager';
export {
saveConfig,
loadConfig,
clearConfig,
hasStoredConfig,
getConfigTimestamp,
exportConfig,
importConfig,
} from './configStorage';

View File

@@ -0,0 +1,631 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState } from 'react';
import {
Button,
Form,
Row,
Col,
Typography,
Modal,
Banner,
Card,
Table,
Tag,
Popconfirm,
Space,
Select,
} from '@douyinfe/semi-ui';
import { IconPlus, IconEdit, IconDelete } from '@douyinfe/semi-icons';
import { API, showError, showSuccess } from '../../helpers';
import { useTranslation } from 'react-i18next';
const { Text } = Typography;
// Preset templates for common OAuth providers
const OAUTH_PRESETS = {
'github-enterprise': {
name: 'GitHub Enterprise',
authorization_endpoint: '/login/oauth/authorize',
token_endpoint: '/login/oauth/access_token',
user_info_endpoint: '/api/v3/user',
scopes: 'user:email',
user_id_field: 'id',
username_field: 'login',
display_name_field: 'name',
email_field: 'email',
},
gitlab: {
name: 'GitLab',
authorization_endpoint: '/oauth/authorize',
token_endpoint: '/oauth/token',
user_info_endpoint: '/api/v4/user',
scopes: 'openid profile email',
user_id_field: 'id',
username_field: 'username',
display_name_field: 'name',
email_field: 'email',
},
gitea: {
name: 'Gitea',
authorization_endpoint: '/login/oauth/authorize',
token_endpoint: '/login/oauth/access_token',
user_info_endpoint: '/api/v1/user',
scopes: 'openid profile email',
user_id_field: 'id',
username_field: 'login',
display_name_field: 'full_name',
email_field: 'email',
},
nextcloud: {
name: 'Nextcloud',
authorization_endpoint: '/apps/oauth2/authorize',
token_endpoint: '/apps/oauth2/api/v1/token',
user_info_endpoint: '/ocs/v2.php/cloud/user?format=json',
scopes: 'openid profile email',
user_id_field: 'ocs.data.id',
username_field: 'ocs.data.id',
display_name_field: 'ocs.data.displayname',
email_field: 'ocs.data.email',
},
keycloak: {
name: 'Keycloak',
authorization_endpoint: '/realms/{realm}/protocol/openid-connect/auth',
token_endpoint: '/realms/{realm}/protocol/openid-connect/token',
user_info_endpoint: '/realms/{realm}/protocol/openid-connect/userinfo',
scopes: 'openid profile email',
user_id_field: 'sub',
username_field: 'preferred_username',
display_name_field: 'name',
email_field: 'email',
},
authentik: {
name: 'Authentik',
authorization_endpoint: '/application/o/authorize/',
token_endpoint: '/application/o/token/',
user_info_endpoint: '/application/o/userinfo/',
scopes: 'openid profile email',
user_id_field: 'sub',
username_field: 'preferred_username',
display_name_field: 'name',
email_field: 'email',
},
ory: {
name: 'ORY Hydra',
authorization_endpoint: '/oauth2/auth',
token_endpoint: '/oauth2/token',
user_info_endpoint: '/userinfo',
scopes: 'openid profile email',
user_id_field: 'sub',
username_field: 'preferred_username',
display_name_field: 'name',
email_field: 'email',
},
};
const CustomOAuthSetting = ({ serverAddress }) => {
const { t } = useTranslation();
const [providers, setProviders] = useState([]);
const [loading, setLoading] = useState(false);
const [modalVisible, setModalVisible] = useState(false);
const [editingProvider, setEditingProvider] = useState(null);
const [formValues, setFormValues] = useState({});
const [selectedPreset, setSelectedPreset] = useState('');
const [baseUrl, setBaseUrl] = useState('');
const formApiRef = React.useRef(null);
const fetchProviders = async () => {
setLoading(true);
try {
const res = await API.get('/api/custom-oauth-provider/');
if (res.data.success) {
setProviders(res.data.data || []);
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('获取自定义 OAuth 提供商列表失败'));
}
setLoading(false);
};
useEffect(() => {
fetchProviders();
}, []);
const handleAdd = () => {
setEditingProvider(null);
setFormValues({
enabled: false,
scopes: 'openid profile email',
user_id_field: 'sub',
username_field: 'preferred_username',
display_name_field: 'name',
email_field: 'email',
auth_style: 0,
});
setSelectedPreset('');
setBaseUrl('');
setModalVisible(true);
};
const handleEdit = (provider) => {
setEditingProvider(provider);
setFormValues({ ...provider });
setSelectedPreset('');
setBaseUrl('');
setModalVisible(true);
};
const handleDelete = async (id) => {
try {
const res = await API.delete(`/api/custom-oauth-provider/${id}`);
if (res.data.success) {
showSuccess(t('删除成功'));
fetchProviders();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('删除失败'));
}
};
const handleSubmit = async () => {
// Validate required fields
const requiredFields = [
'name',
'slug',
'client_id',
'authorization_endpoint',
'token_endpoint',
'user_info_endpoint',
];
if (!editingProvider) {
requiredFields.push('client_secret');
}
for (const field of requiredFields) {
if (!formValues[field]) {
showError(t(`请填写 ${field}`));
return;
}
}
// Validate endpoint URLs must be full URLs
const endpointFields = ['authorization_endpoint', 'token_endpoint', 'user_info_endpoint'];
for (const field of endpointFields) {
const value = formValues[field];
if (value && !value.startsWith('http://') && !value.startsWith('https://')) {
// Check if user selected a preset but forgot to fill server address
if (selectedPreset && !baseUrl) {
showError(t('请先填写服务器地址,以自动生成完整的端点 URL'));
} else {
showError(t('端点 URL 必须是完整地址(以 http:// 或 https:// 开头)'));
}
return;
}
}
try {
let res;
if (editingProvider) {
res = await API.put(
`/api/custom-oauth-provider/${editingProvider.id}`,
formValues
);
} else {
res = await API.post('/api/custom-oauth-provider/', formValues);
}
if (res.data.success) {
showSuccess(editingProvider ? t('更新成功') : t('创建成功'));
setModalVisible(false);
fetchProviders();
} else {
showError(res.data.message);
}
} catch (error) {
showError(editingProvider ? t('更新失败') : t('创建失败'));
}
};
const handlePresetChange = (preset) => {
setSelectedPreset(preset);
if (preset && OAUTH_PRESETS[preset]) {
const presetConfig = OAUTH_PRESETS[preset];
const cleanUrl = baseUrl ? baseUrl.replace(/\/+$/, '') : '';
const newValues = {
name: presetConfig.name,
slug: preset,
scopes: presetConfig.scopes,
user_id_field: presetConfig.user_id_field,
username_field: presetConfig.username_field,
display_name_field: presetConfig.display_name_field,
email_field: presetConfig.email_field,
auth_style: presetConfig.auth_style ?? 0,
};
// Only fill endpoints if server address is provided
if (cleanUrl) {
newValues.authorization_endpoint = cleanUrl + presetConfig.authorization_endpoint;
newValues.token_endpoint = cleanUrl + presetConfig.token_endpoint;
newValues.user_info_endpoint = cleanUrl + presetConfig.user_info_endpoint;
}
setFormValues((prev) => ({ ...prev, ...newValues }));
// Update form fields directly via formApi
if (formApiRef.current) {
Object.entries(newValues).forEach(([key, value]) => {
formApiRef.current.setValue(key, value);
});
}
}
};
const handleBaseUrlChange = (url) => {
setBaseUrl(url);
if (url && selectedPreset && OAUTH_PRESETS[selectedPreset]) {
const presetConfig = OAUTH_PRESETS[selectedPreset];
const cleanUrl = url.replace(/\/+$/, ''); // Remove trailing slashes
const newValues = {
authorization_endpoint: cleanUrl + presetConfig.authorization_endpoint,
token_endpoint: cleanUrl + presetConfig.token_endpoint,
user_info_endpoint: cleanUrl + presetConfig.user_info_endpoint,
};
setFormValues((prev) => ({ ...prev, ...newValues }));
// Update form fields directly via formApi (use merge mode to preserve other fields)
if (formApiRef.current) {
Object.entries(newValues).forEach(([key, value]) => {
formApiRef.current.setValue(key, value);
});
}
}
};
const columns = [
{
title: t('名称'),
dataIndex: 'name',
key: 'name',
},
{
title: 'Slug',
dataIndex: 'slug',
key: 'slug',
render: (slug) => <Tag>{slug}</Tag>,
},
{
title: t('状态'),
dataIndex: 'enabled',
key: 'enabled',
render: (enabled) => (
<Tag color={enabled ? 'green' : 'grey'}>
{enabled ? t('已启用') : t('已禁用')}
</Tag>
),
},
{
title: t('Client ID'),
dataIndex: 'client_id',
key: 'client_id',
render: (id) => (id ? id.substring(0, 20) + '...' : '-'),
},
{
title: t('操作'),
key: 'actions',
render: (_, record) => (
<Space>
<Button
icon={<IconEdit />}
size="small"
onClick={() => handleEdit(record)}
>
{t('编辑')}
</Button>
<Popconfirm
title={t('确定要删除此 OAuth 提供商吗?')}
onConfirm={() => handleDelete(record.id)}
>
<Button icon={<IconDelete />} size="small" type="danger">
{t('删除')}
</Button>
</Popconfirm>
</Space>
),
},
];
return (
<Card>
<Form.Section text={t('自定义 OAuth 提供商')}>
<Banner
type="info"
description={
<>
{t(
'配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商'
)}
<br />
{t('回调 URL 格式')}: {serverAddress || t('网站地址')}/oauth/
{'{slug}'}
</>
}
style={{ marginBottom: 20 }}
/>
<Button
icon={<IconPlus />}
theme="solid"
onClick={handleAdd}
style={{ marginBottom: 16 }}
>
{t('添加 OAuth 提供商')}
</Button>
<Table
columns={columns}
dataSource={providers}
loading={loading}
rowKey="id"
pagination={false}
empty={t('暂无自定义 OAuth 提供商')}
/>
<Modal
title={editingProvider ? t('编辑 OAuth 提供商') : t('添加 OAuth 提供商')}
visible={modalVisible}
onOk={handleSubmit}
onCancel={() => setModalVisible(false)}
okText={t('保存')}
cancelText={t('取消')}
width={800}
>
<Form
initValues={formValues}
onValueChange={(values) => setFormValues(values)}
getFormApi={(api) => (formApiRef.current = api)}
>
{!editingProvider && (
<Row gutter={16} style={{ marginBottom: 16 }}>
<Col span={12}>
<Form.Select
field="preset"
label={t('预设模板')}
placeholder={t('选择预设模板(可选)')}
value={selectedPreset}
onChange={handlePresetChange}
optionList={[
{ value: '', label: t('自定义') },
...Object.entries(OAUTH_PRESETS).map(([key, config]) => ({
value: key,
label: config.name,
})),
]}
/>
</Col>
<Col span={12}>
<Form.Input
field="base_url"
label={
selectedPreset
? t('服务器地址') + ' *'
: t('服务器地址')
}
placeholder={t('例如https://gitea.example.com')}
value={baseUrl}
onChange={handleBaseUrlChange}
extraText={
selectedPreset
? t('必填:请输入服务器地址以自动生成完整端点 URL')
: t('选择预设模板后填写服务器地址可自动填充端点')
}
/>
</Col>
</Row>
)}
<Row gutter={16}>
<Col span={12}>
<Form.Input
field="name"
label={t('显示名称')}
placeholder={t('例如GitHub Enterprise')}
rules={[{ required: true, message: t('请输入显示名称') }]}
/>
</Col>
<Col span={12}>
<Form.Input
field="slug"
label="Slug"
placeholder={t('例如github-enterprise')}
extraText={t('URL 标识,只能包含小写字母、数字和连字符')}
rules={[{ required: true, message: t('请输入 Slug') }]}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={12}>
<Form.Input
field="client_id"
label="Client ID"
placeholder={t('OAuth Client ID')}
rules={[{ required: true, message: t('请输入 Client ID') }]}
/>
</Col>
<Col span={12}>
<Form.Input
field="client_secret"
label="Client Secret"
type="password"
placeholder={
editingProvider
? t('留空则保持原有密钥')
: t('OAuth Client Secret')
}
rules={
editingProvider
? []
: [{ required: true, message: t('请输入 Client Secret') }]
}
/>
</Col>
</Row>
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
{t('OAuth 端点')}
</Text>
<Row gutter={16}>
<Col span={24}>
<Form.Input
field="authorization_endpoint"
label={t('Authorization Endpoint')}
placeholder={
selectedPreset && OAUTH_PRESETS[selectedPreset]
? t('填写服务器地址后自动生成:') +
OAUTH_PRESETS[selectedPreset].authorization_endpoint
: 'https://example.com/oauth/authorize'
}
rules={[
{ required: true, message: t('请输入 Authorization Endpoint') },
]}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={12}>
<Form.Input
field="token_endpoint"
label={t('Token Endpoint')}
placeholder={
selectedPreset && OAUTH_PRESETS[selectedPreset]
? t('自动生成:') + OAUTH_PRESETS[selectedPreset].token_endpoint
: 'https://example.com/oauth/token'
}
rules={[{ required: true, message: t('请输入 Token Endpoint') }]}
/>
</Col>
<Col span={12}>
<Form.Input
field="user_info_endpoint"
label={t('User Info Endpoint')}
placeholder={
selectedPreset && OAUTH_PRESETS[selectedPreset]
? t('自动生成:') + OAUTH_PRESETS[selectedPreset].user_info_endpoint
: 'https://example.com/api/user'
}
rules={[
{ required: true, message: t('请输入 User Info Endpoint') },
]}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={12}>
<Form.Input
field="scopes"
label={t('Scopes')}
placeholder="openid profile email"
/>
</Col>
<Col span={12}>
<Form.Input
field="well_known"
label={t('Well-Known URL')}
placeholder={t('OIDC Discovery 端点(可选)')}
/>
</Col>
</Row>
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
{t('字段映射')}
</Text>
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
{t('配置如何从用户信息 API 响应中提取用户数据,支持 JSONPath 语法')}
</Text>
<Row gutter={16}>
<Col span={12}>
<Form.Input
field="user_id_field"
label={t('用户 ID 字段')}
placeholder={t('例如sub、id、data.user.id')}
extraText={t('用于唯一标识用户的字段路径')}
/>
</Col>
<Col span={12}>
<Form.Input
field="username_field"
label={t('用户名字段')}
placeholder={t('例如preferred_username、login')}
/>
</Col>
</Row>
<Row gutter={16}>
<Col span={12}>
<Form.Input
field="display_name_field"
label={t('显示名称字段')}
placeholder={t('例如name、full_name')}
/>
</Col>
<Col span={12}>
<Form.Input
field="email_field"
label={t('邮箱字段')}
placeholder={t('例如email')}
/>
</Col>
</Row>
<Text strong style={{ display: 'block', margin: '16px 0 8px' }}>
{t('高级选项')}
</Text>
<Row gutter={16}>
<Col span={12}>
<Form.Select
field="auth_style"
label={t('认证方式')}
optionList={[
{ value: 0, label: t('自动检测') },
{ value: 1, label: t('POST 参数') },
{ value: 2, label: t('Basic Auth 头') },
]}
/>
</Col>
<Col span={12}>
<Form.Checkbox field="enabled" noLabel>
{t('启用此 OAuth 提供商')}
</Form.Checkbox>
</Col>
</Row>
</Form>
</Modal>
</Form.Section>
</Card>
);
};
export default CustomOAuthSetting;

View File

@@ -78,6 +78,9 @@ const OperationSetting = () => {
'checkin_setting.enabled': false,
'checkin_setting.min_quota': 1000,
'checkin_setting.max_quota': 10000,
/* 令牌设置 */
'token_setting.max_user_tokens': 1000,
});
let [loading, setLoading] = useState(false);

View File

@@ -42,6 +42,7 @@ import {
} from '../../helpers';
import axios from 'axios';
import { useTranslation } from 'react-i18next';
import CustomOAuthSetting from './CustomOAuthSetting';
const SystemSetting = () => {
const { t } = useTranslation();
@@ -1534,6 +1535,8 @@ const SystemSetting = () => {
</Form.Section>
</Card>
<CustomOAuthSetting serverAddress={inputs.ServerAddress} />
<Card>
<Form.Section text={t('配置 WeChat Server')}>
<Text>{t('用以支持通过微信进行登录注册')}</Text>

View File

@@ -42,10 +42,14 @@ import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si';
import { UserPlus, ShieldCheck } from 'lucide-react';
import TelegramLoginButton from 'react-telegram-login';
import {
API,
showError,
showSuccess,
onGitHubOAuthClicked,
onOIDCClicked,
onLinuxDOOAuthClicked,
onDiscordOAuthClicked,
onCustomOAuthClicked,
} from '../../../../helpers';
import TwoFASetting from '../components/TwoFASetting';
@@ -94,6 +98,66 @@ const AccountManagement = ({
const isBound = (accountId) => Boolean(accountId);
const [showTelegramBindModal, setShowTelegramBindModal] =
React.useState(false);
const [customOAuthBindings, setCustomOAuthBindings] = React.useState([]);
const [customOAuthLoading, setCustomOAuthLoading] = React.useState({});
// Fetch custom OAuth bindings
const loadCustomOAuthBindings = async () => {
try {
const res = await API.get('/api/user/oauth/bindings');
if (res.data.success) {
setCustomOAuthBindings(res.data.data || []);
}
} catch (error) {
// ignore
}
};
// Unbind custom OAuth provider
const handleUnbindCustomOAuth = async (providerId, providerName) => {
Modal.confirm({
title: t('确认解绑'),
content: t('确定要解绑 {{name}} 吗?', { name: providerName }),
okText: t('确认'),
cancelText: t('取消'),
onOk: async () => {
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: true }));
try {
const res = await API.delete(`/api/user/oauth/bindings/${providerId}`);
if (res.data.success) {
showSuccess(t('解绑成功'));
await loadCustomOAuthBindings();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('操作失败'));
} finally {
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
}
},
});
};
// Handle bind custom OAuth
const handleBindCustomOAuth = (provider) => {
onCustomOAuthClicked(provider);
};
// Check if custom OAuth provider is bound
const isCustomOAuthBound = (providerId) => {
return customOAuthBindings.some((b) => b.provider_id === providerId);
};
// Get binding info for a provider
const getCustomOAuthBinding = (providerId) => {
return customOAuthBindings.find((b) => b.provider_id === providerId);
};
React.useEffect(() => {
loadCustomOAuthBindings();
}, []);
const passkeyEnabled = passkeyStatus?.enabled;
const lastUsedLabel = passkeyStatus?.last_used_at
? new Date(passkeyStatus.last_used_at).toLocaleString()
@@ -447,6 +511,64 @@ const AccountManagement = ({
</div>
</div>
</Card>
{/* 自定义 OAuth 提供商绑定 */}
{status.custom_oauth_providers &&
status.custom_oauth_providers.map((provider) => {
const bound = isCustomOAuthBound(provider.id);
const binding = getCustomOAuthBinding(provider.id);
return (
<Card key={provider.slug} className='!rounded-xl'>
<div className='flex items-center justify-between gap-3'>
<div className='flex items-center flex-1 min-w-0'>
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
<IconLock
size='default'
className='text-slate-600 dark:text-slate-300'
/>
</div>
<div className='flex-1 min-w-0'>
<div className='font-medium text-gray-900'>
{provider.name}
</div>
<div className='text-sm text-gray-500 truncate'>
{bound
? renderAccountInfo(
binding?.provider_user_id,
t('{{name}} ID', { name: provider.name }),
)
: t('未绑定')}
</div>
</div>
</div>
<div className='flex-shrink-0'>
{bound ? (
<Button
type='danger'
theme='outline'
size='small'
loading={customOAuthLoading[provider.id]}
onClick={() =>
handleUnbindCustomOAuth(provider.id, provider.name)
}
>
{t('解绑')}
</Button>
) : (
<Button
type='primary'
theme='outline'
size='small'
onClick={() => handleBindCustomOAuth(provider)}
>
{t('绑定')}
</Button>
)}
</div>
</div>
</Card>
);
})}
</div>
</div>
</TabPane>

View File

@@ -1,280 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect } from 'react';
import {
Empty,
Skeleton,
Space,
Tag,
Collapsible,
Tabs,
TabPane,
Typography,
Avatar,
} from '@douyinfe/semi-ui';
import {
IllustrationNoContent,
IllustrationNoContentDark,
} from '@douyinfe/semi-illustrations';
import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons';
import { Settings } from 'lucide-react';
import { renderModelTag, getModelCategories } from '../../../../helpers';
const ModelsList = ({ t, models, modelsLoading, copyText }) => {
const [isModelsExpanded, setIsModelsExpanded] = useState(() => {
// Initialize from localStorage if available
const savedState = localStorage.getItem('modelsExpanded');
return savedState ? JSON.parse(savedState) : false;
});
const [activeModelCategory, setActiveModelCategory] = useState('all');
const MODELS_DISPLAY_COUNT = 25; // 默认显示的模型数量
// Save models expanded state to localStorage whenever it changes
useEffect(() => {
localStorage.setItem('modelsExpanded', JSON.stringify(isModelsExpanded));
}, [isModelsExpanded]);
return (
<div className='py-4'>
{/* 卡片头部 */}
<div className='flex items-center mb-4'>
<Avatar size='small' color='green' className='mr-3 shadow-md'>
<Settings size={16} />
</Avatar>
<div>
<Typography.Text className='text-lg font-medium'>
{t('可用模型')}
</Typography.Text>
<div className='text-xs text-gray-600'>
{t('查看当前可用的所有模型')}
</div>
</div>
</div>
{/* 可用模型部分 */}
<div className='bg-gray-50 dark:bg-gray-800 rounded-xl'>
{modelsLoading ? (
// 骨架屏加载状态 - 模拟实际加载后的布局
<div className='space-y-4'>
{/* 模拟分类标签 */}
<div
className='mb-4'
style={{ borderBottom: '1px solid var(--semi-color-border)' }}
>
<div className='flex overflow-x-auto py-2 gap-2'>
{Array.from({ length: 8 }).map((_, index) => (
<Skeleton.Button
key={`cat-${index}`}
style={{
width: index === 0 ? 130 : 100 + Math.random() * 50,
height: 36,
borderRadius: 8,
}}
/>
))}
</div>
</div>
{/* 模拟模型标签列表 */}
<div className='flex flex-wrap gap-2'>
{Array.from({ length: 20 }).map((_, index) => (
<Skeleton.Button
key={`model-${index}`}
style={{
width: 100 + Math.random() * 100,
height: 32,
borderRadius: 16,
margin: '4px',
}}
/>
))}
</div>
</div>
) : models.length === 0 ? (
<div className='py-8'>
<Empty
image={
<IllustrationNoContent style={{ width: 150, height: 150 }} />
}
darkModeImage={
<IllustrationNoContentDark
style={{ width: 150, height: 150 }}
/>
}
description={t('没有可用模型')}
style={{ padding: '24px 0' }}
/>
</div>
) : (
<>
{/* 模型分类标签页 */}
<div className='mb-4'>
<Tabs
type='card'
activeKey={activeModelCategory}
onChange={(key) => setActiveModelCategory(key)}
className='mt-2'
collapsible
>
{Object.entries(getModelCategories(t)).map(
([key, category]) => {
// 计算该分类下的模型数量
const modelCount =
key === 'all'
? models.length
: models.filter((model) =>
category.filter({ model_name: model }),
).length;
if (modelCount === 0 && key !== 'all') return null;
return (
<TabPane
tab={
<span className='flex items-center gap-2'>
{category.icon && (
<span className='w-4 h-4'>{category.icon}</span>
)}
{category.label}
<Tag
color={
activeModelCategory === key ? 'red' : 'grey'
}
size='small'
shape='circle'
>
{modelCount}
</Tag>
</span>
}
itemKey={key}
key={key}
/>
);
},
)}
</Tabs>
</div>
<div className='bg-white dark:bg-gray-700 rounded-lg p-3'>
{(() => {
// 根据当前选中的分类过滤模型
const categories = getModelCategories(t);
const filteredModels =
activeModelCategory === 'all'
? models
: models.filter((model) =>
categories[activeModelCategory].filter({
model_name: model,
}),
);
// 如果过滤后没有模型,显示空状态
if (filteredModels.length === 0) {
return (
<Empty
image={
<IllustrationNoContent
style={{ width: 120, height: 120 }}
/>
}
darkModeImage={
<IllustrationNoContentDark
style={{ width: 120, height: 120 }}
/>
}
description={t('该分类下没有可用模型')}
style={{ padding: '16px 0' }}
/>
);
}
if (filteredModels.length <= MODELS_DISPLAY_COUNT) {
return (
<Space wrap>
{filteredModels.map((model) =>
renderModelTag(model, {
size: 'small',
shape: 'circle',
onClick: () => copyText(model),
}),
)}
</Space>
);
} else {
return (
<>
<Collapsible isOpen={isModelsExpanded}>
<Space wrap>
{filteredModels.map((model) =>
renderModelTag(model, {
size: 'small',
shape: 'circle',
onClick: () => copyText(model),
}),
)}
<Tag
color='grey'
type='light'
className='cursor-pointer !rounded-lg'
onClick={() => setIsModelsExpanded(false)}
icon={<IconChevronUp />}
>
{t('收起')}
</Tag>
</Space>
</Collapsible>
{!isModelsExpanded && (
<Space wrap>
{filteredModels
.slice(0, MODELS_DISPLAY_COUNT)
.map((model) =>
renderModelTag(model, {
size: 'small',
shape: 'circle',
onClick: () => copyText(model),
}),
)}
<Tag
color='grey'
type='light'
className='cursor-pointer !rounded-lg'
onClick={() => setIsModelsExpanded(true)}
icon={<IconChevronDown />}
>
{t('更多')}{' '}
{filteredModels.length - MODELS_DISPLAY_COUNT}{' '}
{t('个模型')}
</Tag>
</Space>
)}
</>
);
}
})()}
</div>
</>
)}
</div>
</div>
);
};
export default ModelsList;

View File

@@ -1,44 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Typography } from '@douyinfe/semi-ui';
import { Layers } from 'lucide-react';
import CompactModeToggle from '../../common/ui/CompactModeToggle';
const { Text } = Typography;
const ModelsDescription = ({ compactMode, setCompactMode, t }) => {
return (
<div className='flex flex-col md:flex-row justify-between items-start md:items-center gap-2 w-full'>
<div className='flex items-center text-green-500'>
<Layers size={16} className='mr-2' />
<Text>{t('模型管理')}</Text>
</div>
<CompactModeToggle
compactMode={compactMode}
setCompactMode={setCompactMode}
t={t}
/>
</div>
);
};
export default ModelsDescription;

View File

@@ -1,123 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useMemo, useState } from 'react';
import { Modal, Select, Space, Typography } from '@douyinfe/semi-ui';
import { API, showError, showSuccess } from '../../../../helpers';
const { Text } = Typography;
const BindSubscriptionModal = ({ visible, onCancel, user, t, onSuccess }) => {
const [loading, setLoading] = useState(false);
const [plans, setPlans] = useState([]);
const [selectedPlanId, setSelectedPlanId] = useState(null);
const loadPlans = async () => {
setLoading(true);
try {
const res = await API.get('/api/subscription/admin/plans');
if (res.data?.success) {
setPlans(res.data.data || []);
} else {
showError(res.data?.message || t('加载失败'));
}
} catch (e) {
showError(t('请求失败'));
} finally {
setLoading(false);
}
};
useEffect(() => {
if (visible) {
setSelectedPlanId(null);
loadPlans();
}
}, [visible]);
const planOptions = useMemo(() => {
return (plans || []).map((p) => ({
label: `${p?.plan?.title || ''} (${p?.plan?.currency || 'USD'} ${Number(p?.plan?.price_amount || 0)})`,
value: p?.plan?.id,
}));
}, [plans]);
const bind = async () => {
if (!user?.id) {
showError(t('用户信息缺失'));
return;
}
if (!selectedPlanId) {
showError(t('请选择订阅套餐'));
return;
}
setLoading(true);
try {
const res = await API.post('/api/subscription/admin/bind', {
user_id: user.id,
plan_id: selectedPlanId,
});
if (res.data?.success) {
showSuccess(t('绑定成功'));
onSuccess?.();
onCancel?.();
} else {
showError(res.data?.message || t('绑定失败'));
}
} catch (e) {
showError(t('请求失败'));
} finally {
setLoading(false);
}
};
return (
<Modal
title={t('绑定订阅套餐')}
visible={visible}
onCancel={onCancel}
onOk={bind}
confirmLoading={loading}
maskClosable={false}
centered
>
<Space vertical style={{ width: '100%' }} spacing='medium'>
<div className='text-sm'>
<Text strong>{t('用户')}</Text>
<Text>{user?.username}</Text>
<Text type='tertiary'> (ID: {user?.id})</Text>
</div>
<Select
placeholder={t('选择订阅套餐')}
optionList={planOptions}
value={selectedPlanId}
onChange={setSelectedPlanId}
loading={loading}
filter
style={{ width: '100%' }}
/>
<div className='text-xs text-gray-500'>
{t('绑定后会立即生成用户订阅(无需支付),有效期按套餐配置计算。')}
</div>
</Space>
</Modal>
);
};
export default BindSubscriptionModal;

View File

@@ -294,6 +294,48 @@ export async function onLinuxDOOAuthClicked(
);
}
/**
* Initiate custom OAuth login
* @param {Object} provider - Custom OAuth provider config from status API
* @param {string} provider.slug - Provider slug (used for callback URL)
* @param {string} provider.client_id - OAuth client ID
* @param {string} provider.authorization_endpoint - Authorization URL
* @param {string} provider.scopes - OAuth scopes (space-separated)
* @param {Object} options - Options
* @param {boolean} options.shouldLogout - Whether to logout first
*/
export async function onCustomOAuthClicked(provider, options = {}) {
const state = await prepareOAuthState(options);
if (!state) return;
try {
const redirect_uri = `${window.location.origin}/oauth/${provider.slug}`;
// Check if authorization_endpoint is a full URL or relative path
let authUrl;
if (provider.authorization_endpoint.startsWith('http://') ||
provider.authorization_endpoint.startsWith('https://')) {
authUrl = new URL(provider.authorization_endpoint);
} else {
// Relative path - this is a configuration error, show error message
console.error('Custom OAuth authorization_endpoint must be a full URL:', provider.authorization_endpoint);
showError('OAuth 配置错误:授权端点必须是完整的 URL以 http:// 或 https:// 开头)');
return;
}
authUrl.searchParams.set('client_id', provider.client_id);
authUrl.searchParams.set('redirect_uri', redirect_uri);
authUrl.searchParams.set('response_type', 'code');
authUrl.searchParams.set('scope', provider.scopes || 'openid profile email');
authUrl.searchParams.set('state', state);
window.open(authUrl.toString());
} catch (error) {
console.error('Failed to initiate custom OAuth:', error);
showError('OAuth 登录失败:' + (error.message || '未知错误'));
}
}
let channelModels = undefined;
export async function loadChannelModels() {
const res = await API.get('/api/models');

View File

@@ -571,7 +571,6 @@ export const modelColorMap = {
'claude-3-opus-20240229': 'rgb(255,132,31)', // 橙红色
'claude-3-sonnet-20240229': 'rgb(253,135,93)', // 橙色
'claude-3-haiku-20240307': 'rgb(255,175,146)', // 浅橙色
'claude-2.1': 'rgb(255,209,190)', // 浅橙色(略有区别)
};
export function modelToColor(modelName) {

View File

@@ -1,312 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { useState, useCallback } from 'react';
import { API } from '../../helpers';
import { showError } from '../../helpers';
export const useDeploymentResources = () => {
const [hardwareTypes, setHardwareTypes] = useState([]);
const [hardwareTotalAvailable, setHardwareTotalAvailable] = useState(0);
const [locations, setLocations] = useState([]);
const [locationsTotalAvailable, setLocationsTotalAvailable] = useState(0);
const [availableReplicas, setAvailableReplicas] = useState([]);
const [priceEstimation, setPriceEstimation] = useState(null);
const [loadingHardware, setLoadingHardware] = useState(false);
const [loadingLocations, setLoadingLocations] = useState(false);
const [loadingReplicas, setLoadingReplicas] = useState(false);
const [loadingPrice, setLoadingPrice] = useState(false);
const fetchHardwareTypes = useCallback(async () => {
try {
setLoadingHardware(true);
const response = await API.get('/api/deployments/hardware-types');
if (response.data.success) {
const { hardware_types: hardwareList = [], total_available } =
response.data.data || {};
const normalizedHardware = hardwareList.map((hardware) => {
const availableCountValue = Number(hardware.available_count);
const availableCount = Number.isNaN(availableCountValue)
? 0
: availableCountValue;
const availableBool =
typeof hardware.available === 'boolean'
? hardware.available
: availableCount > 0;
return {
...hardware,
available: availableBool,
available_count: availableCount,
};
});
const providedTotal = Number(total_available);
const fallbackTotal = normalizedHardware.reduce(
(acc, item) =>
acc +
(Number.isNaN(item.available_count) ? 0 : item.available_count),
0,
);
const hasProvidedTotal =
total_available !== undefined &&
total_available !== null &&
total_available !== '' &&
!Number.isNaN(providedTotal);
setHardwareTypes(normalizedHardware);
setHardwareTotalAvailable(
hasProvidedTotal ? providedTotal : fallbackTotal,
);
return normalizedHardware;
} else {
showError('获取硬件类型失败: ' + response.data.message);
setHardwareTotalAvailable(0);
return [];
}
} catch (error) {
showError('获取硬件类型失败: ' + error.message);
setHardwareTotalAvailable(0);
return [];
} finally {
setLoadingHardware(false);
}
}, []);
const fetchLocations = useCallback(async (hardwareId, gpuCount = 1) => {
if (!hardwareId) {
setLocations([]);
setLocationsTotalAvailable(0);
return [];
}
try {
setLoadingLocations(true);
const response = await API.get(
`/api/deployments/available-replicas?hardware_id=${hardwareId}&gpu_count=${gpuCount}`,
);
if (response.data.success) {
const replicas = response.data.data?.replicas || [];
const nextLocationsMap = new Map();
replicas.forEach((replica) => {
const rawId = replica?.location_id ?? replica?.location?.id;
if (rawId === null || rawId === undefined) return;
const mapKey = String(rawId);
if (nextLocationsMap.has(mapKey)) return;
const rawIso2 =
replica?.iso2 ?? replica?.location_iso2 ?? replica?.location?.iso2;
const iso2 = rawIso2 ? String(rawIso2).toUpperCase() : '';
const name =
replica?.location_name ??
replica?.location?.name ??
replica?.name ??
String(rawId);
nextLocationsMap.set(mapKey, {
id: rawId,
name: String(name),
iso2,
region:
replica?.region ??
replica?.location_region ??
replica?.location?.region,
country:
replica?.country ??
replica?.location_country ??
replica?.location?.country,
code:
replica?.code ??
replica?.location_code ??
replica?.location?.code,
available: Number(replica?.available_count) || 0,
});
});
const normalizedLocations = Array.from(nextLocationsMap.values());
setLocations(normalizedLocations);
setLocationsTotalAvailable(
normalizedLocations.reduce(
(acc, item) => acc + (item.available || 0),
0,
),
);
return normalizedLocations;
} else {
showError('获取部署位置失败: ' + response.data.message);
setLocationsTotalAvailable(0);
return [];
}
} catch (error) {
showError('获取部署位置失败: ' + error.message);
setLocationsTotalAvailable(0);
return [];
} finally {
setLoadingLocations(false);
}
}, []);
const fetchAvailableReplicas = useCallback(
async (hardwareId, gpuCount = 1) => {
if (!hardwareId) {
setAvailableReplicas([]);
return [];
}
try {
setLoadingReplicas(true);
const response = await API.get(
`/api/deployments/available-replicas?hardware_id=${hardwareId}&gpu_count=${gpuCount}`,
);
if (response.data.success) {
const replicas = response.data.data.replicas || [];
setAvailableReplicas(replicas);
return replicas;
} else {
showError('获取可用资源失败: ' + response.data.message);
setAvailableReplicas([]);
return [];
}
} catch (error) {
console.error('Load available replicas error:', error);
setAvailableReplicas([]);
return [];
} finally {
setLoadingReplicas(false);
}
},
[],
);
const calculatePrice = useCallback(async (params) => {
const {
locationIds,
hardwareId,
gpusPerContainer,
durationHours,
replicaCount,
} = params;
if (
!locationIds?.length ||
!hardwareId ||
!gpusPerContainer ||
!durationHours ||
!replicaCount
) {
setPriceEstimation(null);
return null;
}
try {
setLoadingPrice(true);
const requestData = {
location_ids: locationIds,
hardware_id: hardwareId,
gpus_per_container: gpusPerContainer,
duration_hours: durationHours,
replica_count: replicaCount,
};
const response = await API.post(
'/api/deployments/price-estimation',
requestData,
);
if (response.data.success) {
const estimation = response.data.data;
setPriceEstimation(estimation);
return estimation;
} else {
showError('价格计算失败: ' + response.data.message);
setPriceEstimation(null);
return null;
}
} catch (error) {
console.error('Price calculation error:', error);
setPriceEstimation(null);
return null;
} finally {
setLoadingPrice(false);
}
}, []);
const checkClusterNameAvailability = useCallback(async (name) => {
if (!name?.trim()) return false;
try {
const response = await API.get(
`/api/deployments/check-name?name=${encodeURIComponent(name.trim())}`,
);
if (response.data.success) {
return response.data.data.available;
} else {
showError('检查名称可用性失败: ' + response.data.message);
return false;
}
} catch (error) {
console.error('Check cluster name availability error:', error);
return false;
}
}, []);
const createDeployment = useCallback(async (deploymentData) => {
try {
const response = await API.post('/api/deployments', deploymentData);
if (response.data.success) {
return response.data.data;
} else {
throw new Error(response.data.message || '创建部署失败');
}
} catch (error) {
throw error;
}
}, []);
return {
// Data
hardwareTypes,
hardwareTotalAvailable,
locations,
locationsTotalAvailable,
availableReplicas,
priceEstimation,
// Loading states
loadingHardware,
loadingLocations,
loadingReplicas,
loadingPrice,
// Functions
fetchHardwareTypes,
fetchLocations,
fetchAvailableReplicas,
calculatePrice,
checkClusterNameAvailability,
createDeployment,
// Clear functions
clearPriceEstimation: () => setPriceEstimation(null),
clearAvailableReplicas: () => setAvailableReplicas([]),
};
};
export default useDeploymentResources;

View File

@@ -1,286 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { useState } from 'react';
import { API, showError, showSuccess } from '../../helpers';
export const useEnhancedDeploymentActions = (t) => {
const [loading, setLoading] = useState({});
// Set loading state for specific operation
const setOperationLoading = (operation, deploymentId, isLoading) => {
setLoading((prev) => ({
...prev,
[`${operation}_${deploymentId}`]: isLoading,
}));
};
// Get loading state for specific operation
const isOperationLoading = (operation, deploymentId) => {
return loading[`${operation}_${deploymentId}`] || false;
};
// Extend deployment duration
const extendDeployment = async (deploymentId, durationHours) => {
try {
setOperationLoading('extend', deploymentId, true);
const response = await API.post(
`/api/deployments/${deploymentId}/extend`,
{
duration_hours: durationHours,
},
);
if (response.data.success) {
showSuccess(t('容器时长延长成功'));
return response.data.data;
}
} catch (error) {
showError(
t('延长时长失败') +
': ' +
(error.response?.data?.message || error.message),
);
throw error;
} finally {
setOperationLoading('extend', deploymentId, false);
}
};
// Get deployment details
const getDeploymentDetails = async (deploymentId) => {
try {
setOperationLoading('details', deploymentId, true);
const response = await API.get(`/api/deployments/${deploymentId}`);
if (response.data.success) {
return response.data.data;
}
} catch (error) {
showError(
t('获取详情失败') +
': ' +
(error.response?.data?.message || error.message),
);
throw error;
} finally {
setOperationLoading('details', deploymentId, false);
}
};
// Get deployment logs
const getDeploymentLogs = async (deploymentId, options = {}) => {
try {
setOperationLoading('logs', deploymentId, true);
const params = new URLSearchParams();
if (options.containerId)
params.append('container_id', options.containerId);
if (options.level) params.append('level', options.level);
if (options.limit) params.append('limit', options.limit.toString());
if (options.cursor) params.append('cursor', options.cursor);
if (options.follow) params.append('follow', 'true');
if (options.startTime) params.append('start_time', options.startTime);
if (options.endTime) params.append('end_time', options.endTime);
const response = await API.get(
`/api/deployments/${deploymentId}/logs?${params}`,
);
if (response.data.success) {
return response.data.data;
}
} catch (error) {
showError(
t('获取日志失败') +
': ' +
(error.response?.data?.message || error.message),
);
throw error;
} finally {
setOperationLoading('logs', deploymentId, false);
}
};
// Update deployment configuration
const updateDeploymentConfig = async (deploymentId, config) => {
try {
setOperationLoading('config', deploymentId, true);
const response = await API.put(
`/api/deployments/${deploymentId}`,
config,
);
if (response.data.success) {
showSuccess(t('容器配置更新成功'));
return response.data.data;
}
} catch (error) {
showError(
t('更新配置失败') +
': ' +
(error.response?.data?.message || error.message),
);
throw error;
} finally {
setOperationLoading('config', deploymentId, false);
}
};
// Delete (destroy) deployment
const deleteDeployment = async (deploymentId) => {
try {
setOperationLoading('delete', deploymentId, true);
const response = await API.delete(`/api/deployments/${deploymentId}`);
if (response.data.success) {
showSuccess(t('容器销毁请求已提交'));
return response.data.data;
}
} catch (error) {
showError(
t('销毁容器失败') +
': ' +
(error.response?.data?.message || error.message),
);
throw error;
} finally {
setOperationLoading('delete', deploymentId, false);
}
};
// Update deployment name
const updateDeploymentName = async (deploymentId, newName) => {
try {
setOperationLoading('rename', deploymentId, true);
const response = await API.put(`/api/deployments/${deploymentId}/name`, {
name: newName,
});
if (response.data.success) {
showSuccess(t('容器名称更新成功'));
return response.data.data;
}
} catch (error) {
showError(
t('更新名称失败') +
': ' +
(error.response?.data?.message || error.message),
);
throw error;
} finally {
setOperationLoading('rename', deploymentId, false);
}
};
// Batch operations
const batchDelete = async (deploymentIds) => {
try {
setOperationLoading('batch_delete', 'all', true);
const results = await Promise.allSettled(
deploymentIds.map((id) => deleteDeployment(id)),
);
const successful = results.filter((r) => r.status === 'fulfilled').length;
const failed = results.filter((r) => r.status === 'rejected').length;
if (successful > 0) {
showSuccess(
t('批量操作完成: {{success}}个成功, {{failed}}个失败', {
success: successful,
failed: failed,
}),
);
}
return { successful, failed };
} catch (error) {
showError(t('批量操作失败') + ': ' + error.message);
throw error;
} finally {
setOperationLoading('batch_delete', 'all', false);
}
};
// Export logs
const exportLogs = async (deploymentId, options = {}) => {
try {
setOperationLoading('export_logs', deploymentId, true);
const logs = await getDeploymentLogs(deploymentId, {
...options,
limit: 10000, // Get more logs for export
});
if (logs && logs.logs) {
const logText = logs.logs
.map(
(log) =>
`[${new Date(log.timestamp).toISOString()}] [${log.level}] ${log.source ? `[${log.source}] ` : ''}${log.message}`,
)
.join('\n');
const blob = new Blob([logText], { type: 'text/plain' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `deployment-${deploymentId}-logs-${new Date().toISOString().split('T')[0]}.txt`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
showSuccess(t('日志导出成功'));
}
} catch (error) {
showError(t('导出日志失败') + ': ' + error.message);
throw error;
} finally {
setOperationLoading('export_logs', deploymentId, false);
}
};
return {
// Actions
extendDeployment,
getDeploymentDetails,
getDeploymentLogs,
updateDeploymentConfig,
deleteDeployment,
updateDeploymentName,
batchDelete,
exportLogs,
// Loading states
isOperationLoading,
loading,
// Utility
setOperationLoading,
};
};
export default useEnhancedDeploymentActions;

View File

@@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => {
const [tokenCount, setTokenCount] = useState(0);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searching, setSearching] = useState(false);
const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图
// Selection state
const [selectedKeys, setSelectedKeys] = useState([]);
@@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => {
// Load tokens function
const loadTokens = async (page = 1, size = pageSize) => {
setLoading(true);
setSearchMode(false);
const res = await API.get(`/api/token/?p=${page}&size=${size}`);
const { success, message, data } = res.data;
if (success) {
@@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => {
};
// Search tokens function
const searchTokens = async () => {
const searchTokens = async (page = 1, size = pageSize) => {
const { searchKeyword, searchToken } = getFormValues();
if (searchKeyword === '' && searchToken === '') {
setSearchMode(false);
await loadTokens(1);
return;
}
setSearching(true);
const res = await API.get(
`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`,
`/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
);
const { success, message, data } = res.data;
if (success) {
setTokens(data);
setTokenCount(data.length);
setActivePage(1);
setSearchMode(true);
syncPageData(data);
} else {
showError(message);
}
@@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => {
// Page handlers
const handlePageChange = (page) => {
loadTokens(page, pageSize).then();
if (searchMode) {
searchTokens(page, pageSize).then();
} else {
loadTokens(page, pageSize).then();
}
};
const handlePageSizeChange = async (size) => {
setPageSize(size);
await loadTokens(1, size);
if (searchMode) {
await searchTokens(1, size);
} else {
await loadTokens(1, size);
}
};
// Row selection handlers

View File

@@ -2795,6 +2795,49 @@
"语言偏好": "Language Preference",
"选择您的首选界面语言,设置将自动保存并同步到所有设备": "Select your preferred interface language. Settings will be saved automatically and synced across all devices",
"语言偏好已保存": "Language preference saved",
"提示语言偏好会同步到您登录的所有设备并影响API返回的错误消息语言。": "Note: Language preference syncs across all your logged-in devices and affects the language of API error messages."
"提示语言偏好会同步到您登录的所有设备并影响API返回的错误消息语言。": "Note: Language preference syncs across all your logged-in devices and affects the language of API error messages.",
"自定义 OAuth 提供商": "Custom OAuth Providers",
"配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商": "Configure custom OAuth providers, supports GitHub Enterprise, GitLab, Gitea, Nextcloud, Keycloak, ORY and other OAuth 2.0 compatible identity providers",
"回调 URL 格式": "Callback URL format",
"添加提供商": "Add Provider",
"编辑提供商": "Edit Provider",
"选择预设...": "Select preset...",
"输入基础 URL": "Enter base URL",
"例如": "e.g.",
"提供商名称": "Provider Name",
"标识符 (Slug)": "Slug",
"授权端点": "Authorization Endpoint",
"令牌端点": "Token Endpoint",
"用户信息端点": "User Info Endpoint",
"用户 ID 字段": "User ID Field",
"支持 JSONPath如 sub, id, data.user.id": "Supports JSONPath, e.g. sub, id, data.user.id",
"用户名字段": "Username Field",
"支持 JSONPath如 preferred_username, login, data.user.username": "Supports JSONPath, e.g. preferred_username, login, data.user.username",
"显示名称字段": "Display Name Field",
"支持 JSONPath如 name, display_name, data.user.name": "Supports JSONPath, e.g. name, display_name, data.user.name",
"邮箱字段": "Email Field",
"支持 JSONPath如 email, data.user.email": "Supports JSONPath, e.g. email, data.user.email",
"授权范围 (Scopes)": "Scopes",
"认证方式": "Auth Style",
"自动检测": "Auto-detect",
"参数传递": "In Parameters",
"Basic Auth 头": "Basic Auth Header",
"暂无自定义 OAuth 提供商": "No custom OAuth providers",
"确定要删除该提供商吗?": "Are you sure you want to delete this provider?",
"创建成功": "Created successfully",
"更新成功": "Updated successfully",
"确认解绑": "Confirm Unbind",
"确定要解绑 {{name}} 吗?": "Are you sure you want to unbind {{name}}?",
"解绑成功": "Unbind successful",
"{{name}} ID": "{{name}} ID",
"使用 {{name}} 继续": "Continue with {{name}}",
"端点 URL 必须以 http:// 或 https:// 开头:": "Endpoint URL must start with http:// or https://: ",
"OAuth 配置错误:授权端点必须是完整的 URL以 http:// 或 https:// 开头)": "OAuth configuration error: Authorization endpoint must be a full URL (starting with http:// or https://)",
"OAuth 登录失败:": "OAuth login failed: ",
"必填:请输入服务器地址以自动生成完整端点 URL": "Required: Enter server address to auto-generate full endpoint URLs",
"填写服务器地址后自动生成:": "Auto-generated after entering server address: ",
"自动生成:": "Auto-generated: ",
"请先填写服务器地址,以自动生成完整的端点 URL": "Please enter the server address first to auto-generate full endpoint URLs",
"端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "Endpoint URL must be a full address (starting with http:// or https://)"
}
}

View File

@@ -2740,6 +2740,49 @@
"语言偏好": "语言偏好",
"选择您的首选界面语言,设置将自动保存并同步到所有设备": "选择您的首选界面语言,设置将自动保存并同步到所有设备",
"语言偏好已保存": "语言偏好已保存",
"提示语言偏好会同步到您登录的所有设备并影响API返回的错误消息语言。": "提示语言偏好会同步到您登录的所有设备并影响API返回的错误消息语言。"
"提示语言偏好会同步到您登录的所有设备并影响API返回的错误消息语言。": "提示语言偏好会同步到您登录的所有设备并影响API返回的错误消息语言。",
"自定义 OAuth 提供商": "自定义 OAuth 提供商",
"配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商": "配置自定义 OAuth 提供商,支持 GitHub Enterprise、GitLab、Gitea、Nextcloud、Keycloak、ORY 等兼容 OAuth 2.0 协议的身份提供商",
"回调 URL 格式": "回调 URL 格式",
"添加提供商": "添加提供商",
"编辑提供商": "编辑提供商",
"选择预设...": "选择预设...",
"输入基础 URL": "输入基础 URL",
"例如": "例如",
"提供商名称": "提供商名称",
"标识符 (Slug)": "标识符 (Slug)",
"授权端点": "授权端点",
"令牌端点": "令牌端点",
"用户信息端点": "用户信息端点",
"用户 ID 字段": "用户 ID 字段",
"支持 JSONPath如 sub, id, data.user.id": "支持 JSONPath如 sub, id, data.user.id",
"用户名字段": "用户名字段",
"支持 JSONPath如 preferred_username, login, data.user.username": "支持 JSONPath如 preferred_username, login, data.user.username",
"显示名称字段": "显示名称字段",
"支持 JSONPath如 name, display_name, data.user.name": "支持 JSONPath如 name, display_name, data.user.name",
"邮箱字段": "邮箱字段",
"支持 JSONPath如 email, data.user.email": "支持 JSONPath如 email, data.user.email",
"授权范围 (Scopes)": "授权范围 (Scopes)",
"认证方式": "认证方式",
"自动检测": "自动检测",
"参数传递": "参数传递",
"Basic Auth 头": "Basic Auth 头",
"暂无自定义 OAuth 提供商": "暂无自定义 OAuth 提供商",
"确定要删除该提供商吗?": "确定要删除该提供商吗?",
"创建成功": "创建成功",
"更新成功": "更新成功",
"确认解绑": "确认解绑",
"确定要解绑 {{name}} 吗?": "确定要解绑 {{name}} 吗?",
"解绑成功": "解绑成功",
"{{name}} ID": "{{name}} ID",
"使用 {{name}} 继续": "使用 {{name}} 继续",
"端点 URL 必须以 http:// 或 https:// 开头:": "端点 URL 必须以 http:// 或 https:// 开头:",
"OAuth 配置错误:授权端点必须是完整的 URL以 http:// 或 https:// 开头)": "OAuth 配置错误:授权端点必须是完整的 URL以 http:// 或 https:// 开头)",
"OAuth 登录失败:": "OAuth 登录失败:",
"必填:请输入服务器地址以自动生成完整端点 URL": "必填:请输入服务器地址以自动生成完整端点 URL",
"填写服务器地址后自动生成:": "填写服务器地址后自动生成:",
"自动生成:": "自动生成:",
"请先填写服务器地址,以自动生成完整的端点 URL": "请先填写服务器地址,以自动生成完整的端点 URL",
"端点 URL 必须是完整地址(以 http:// 或 https:// 开头)": "端点 URL 必须是完整地址(以 http:// 或 https:// 开头)"
}
}

View File

@@ -56,6 +56,7 @@ export default function GeneralSettings(props) {
DefaultCollapseSidebar: false,
DemoSiteEnabled: false,
SelfUseModeEnabled: false,
'token_setting.max_user_tokens': 1000,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -287,6 +288,19 @@ export default function GeneralSettings(props) {
/>
</Col>
</Row>
<Row gutter={16}>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('用户最大令牌数量')}
field={'token_setting.max_user_tokens'}
step={1}
min={1}
extraText={t('每个用户最多可创建的令牌数量,默认 1000设置过大可能会影响性能')}
placeholder={'1000'}
onChange={handleFieldChange('token_setting.max_user_tokens')}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存通用设置')}

View File

@@ -1,488 +0,0 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { useState, useEffect, useContext } from 'react';
import { useTranslation } from 'react-i18next';
import {
Card,
Button,
Switch,
Typography,
Row,
Col,
Avatar,
} from '@douyinfe/semi-ui';
import { API, showSuccess, showError } from '../../../helpers';
import { StatusContext } from '../../../context/Status';
import { UserContext } from '../../../context/User';
import { useUserPermissions } from '../../../hooks/common/useUserPermissions';
import { mergeAdminConfig, useSidebar } from '../../../hooks/common/useSidebar';
import { Settings } from 'lucide-react';
const { Text } = Typography;
export default function SettingsSidebarModulesUser() {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [statusState] = useContext(StatusContext);
// 使用后端权限验证替代前端角色判断
const {
permissions,
loading: permissionsLoading,
hasSidebarSettingsPermission,
isSidebarSectionAllowed,
isSidebarModuleAllowed,
} = useUserPermissions();
// 使用useSidebar钩子获取刷新方法
const { refreshUserConfig } = useSidebar();
// 如果没有边栏设置权限,不显示此组件
if (!permissionsLoading && !hasSidebarSettingsPermission()) {
return null;
}
// 权限加载中,显示加载状态
if (permissionsLoading) {
return null;
}
// 根据用户权限生成默认配置
const generateDefaultConfig = () => {
const defaultConfig = {};
// 聊天区域 - 所有用户都可以访问
if (isSidebarSectionAllowed('chat')) {
defaultConfig.chat = {
enabled: true,
playground: isSidebarModuleAllowed('chat', 'playground'),
chat: isSidebarModuleAllowed('chat', 'chat'),
};
}
// 控制台区域 - 所有用户都可以访问
if (isSidebarSectionAllowed('console')) {
defaultConfig.console = {
enabled: true,
detail: isSidebarModuleAllowed('console', 'detail'),
token: isSidebarModuleAllowed('console', 'token'),
log: isSidebarModuleAllowed('console', 'log'),
midjourney: isSidebarModuleAllowed('console', 'midjourney'),
task: isSidebarModuleAllowed('console', 'task'),
};
}
// 个人中心区域 - 所有用户都可以访问
if (isSidebarSectionAllowed('personal')) {
defaultConfig.personal = {
enabled: true,
topup: isSidebarModuleAllowed('personal', 'topup'),
personal: isSidebarModuleAllowed('personal', 'personal'),
};
}
// 管理员区域 - 只有管理员可以访问
if (isSidebarSectionAllowed('admin')) {
defaultConfig.admin = {
enabled: true,
channel: isSidebarModuleAllowed('admin', 'channel'),
models: isSidebarModuleAllowed('admin', 'models'),
deployment: isSidebarModuleAllowed('admin', 'deployment'),
redemption: isSidebarModuleAllowed('admin', 'redemption'),
user: isSidebarModuleAllowed('admin', 'user'),
subscription: isSidebarModuleAllowed('admin', 'subscription'),
setting: isSidebarModuleAllowed('admin', 'setting'),
};
}
return defaultConfig;
};
// 用户个人左侧边栏模块设置
const [sidebarModulesUser, setSidebarModulesUser] = useState({});
// 管理员全局配置
const [adminConfig, setAdminConfig] = useState(null);
// 处理区域级别开关变更
function handleSectionChange(sectionKey) {
return (checked) => {
const newModules = {
...sidebarModulesUser,
[sectionKey]: {
...sidebarModulesUser[sectionKey],
enabled: checked,
},
};
setSidebarModulesUser(newModules);
console.log('用户边栏区域配置变更:', sectionKey, checked, newModules);
};
}
// 处理功能级别开关变更
function handleModuleChange(sectionKey, moduleKey) {
return (checked) => {
const newModules = {
...sidebarModulesUser,
[sectionKey]: {
...sidebarModulesUser[sectionKey],
[moduleKey]: checked,
},
};
setSidebarModulesUser(newModules);
console.log(
'用户边栏功能配置变更:',
sectionKey,
moduleKey,
checked,
newModules,
);
};
}
// 重置为默认配置(基于权限过滤)
function resetSidebarModules() {
const defaultConfig = generateDefaultConfig();
setSidebarModulesUser(defaultConfig);
showSuccess(t('已重置为默认配置'));
console.log('用户边栏配置重置为默认:', defaultConfig);
}
// 保存配置
async function onSubmit() {
setLoading(true);
try {
console.log('保存用户边栏配置:', sidebarModulesUser);
const res = await API.put('/api/user/self', {
sidebar_modules: JSON.stringify(sidebarModulesUser),
});
const { success, message } = res.data;
if (success) {
showSuccess(t('保存成功'));
console.log('用户边栏配置保存成功');
// 刷新useSidebar钩子中的用户配置实现实时更新
await refreshUserConfig();
console.log('用户边栏配置已刷新,边栏将立即更新');
} else {
showError(message);
console.error('用户边栏配置保存失败:', message);
}
} catch (error) {
showError(t('保存失败,请重试'));
console.error('用户边栏配置保存异常:', error);
} finally {
setLoading(false);
}
}
// 统一的配置加载逻辑
useEffect(() => {
const loadConfigs = async () => {
try {
// 获取管理员全局配置
if (statusState?.status?.SidebarModulesAdmin) {
try {
const adminConf = JSON.parse(
statusState.status.SidebarModulesAdmin,
);
const mergedAdminConf = mergeAdminConfig(adminConf);
setAdminConfig(mergedAdminConf);
console.log('加载管理员边栏配置:', mergedAdminConf);
} catch (error) {
const mergedAdminConf = mergeAdminConfig(null);
setAdminConfig(mergedAdminConf);
console.log(
'加载管理员边栏配置失败,使用默认配置:',
mergedAdminConf,
);
}
} else {
const mergedAdminConf = mergeAdminConfig(null);
setAdminConfig(mergedAdminConf);
console.log('管理员边栏配置缺失,使用默认配置:', mergedAdminConf);
}
// 获取用户个人配置
const userRes = await API.get('/api/user/self');
if (userRes.data.success && userRes.data.data.sidebar_modules) {
let userConf;
// 检查sidebar_modules是字符串还是对象
if (typeof userRes.data.data.sidebar_modules === 'string') {
userConf = JSON.parse(userRes.data.data.sidebar_modules);
} else {
userConf = userRes.data.data.sidebar_modules;
}
console.log('从API加载的用户配置:', userConf);
// 确保用户配置也经过权限过滤
const filteredUserConf = {};
Object.keys(userConf).forEach((sectionKey) => {
if (isSidebarSectionAllowed(sectionKey)) {
filteredUserConf[sectionKey] = { ...userConf[sectionKey] };
// 过滤不允许的模块
Object.keys(userConf[sectionKey]).forEach((moduleKey) => {
if (
moduleKey !== 'enabled' &&
!isSidebarModuleAllowed(sectionKey, moduleKey)
) {
delete filteredUserConf[sectionKey][moduleKey];
}
});
}
});
setSidebarModulesUser(filteredUserConf);
console.log('权限过滤后的用户配置:', filteredUserConf);
} else {
// 如果用户没有配置,使用权限过滤后的默认配置
const defaultConfig = generateDefaultConfig();
setSidebarModulesUser(defaultConfig);
console.log('用户无配置,使用默认配置:', defaultConfig);
}
} catch (error) {
console.error('加载边栏配置失败:', error);
// 出错时也使用默认配置
const defaultConfig = generateDefaultConfig();
setSidebarModulesUser(defaultConfig);
}
};
// 只有权限加载完成且有边栏设置权限时才加载配置
if (!permissionsLoading && hasSidebarSettingsPermission()) {
loadConfigs();
}
}, [
statusState,
permissionsLoading,
hasSidebarSettingsPermission,
isSidebarSectionAllowed,
isSidebarModuleAllowed,
]);
// 检查功能是否被管理员允许
const isAllowedByAdmin = (sectionKey, moduleKey = null) => {
if (!adminConfig) return true;
if (moduleKey) {
return (
adminConfig[sectionKey]?.enabled && adminConfig[sectionKey]?.[moduleKey]
);
} else {
return adminConfig[sectionKey]?.enabled;
}
};
// 区域配置数据(根据后端权限过滤)
const sectionConfigs = [
{
key: 'chat',
title: t('聊天区域'),
description: t('操练场和聊天功能'),
modules: [
{
key: 'playground',
title: t('操练场'),
description: t('AI模型测试环境'),
},
{ key: 'chat', title: t('聊天'), description: t('聊天会话管理') },
],
},
{
key: 'console',
title: t('控制台区域'),
description: t('数据管理和日志查看'),
modules: [
{ key: 'detail', title: t('数据看板'), description: t('系统数据统计') },
{ key: 'token', title: t('令牌管理'), description: t('API令牌管理') },
{ key: 'log', title: t('使用日志'), description: t('API使用记录') },
{
key: 'midjourney',
title: t('绘图日志'),
description: t('绘图任务记录'),
},
{ key: 'task', title: t('任务日志'), description: t('系统任务记录') },
],
},
{
key: 'personal',
title: t('个人中心区域'),
description: t('用户个人功能'),
modules: [
{ key: 'topup', title: t('钱包管理'), description: t('余额充值管理') },
{
key: 'personal',
title: t('个人设置'),
description: t('个人信息设置'),
},
],
},
{
key: 'admin',
title: t('管理员区域'),
description: t('系统管理功能'),
modules: [
{ key: 'channel', title: t('渠道管理'), description: t('API渠道配置') },
{ key: 'models', title: t('模型管理'), description: t('AI模型配置') },
{
key: 'deployment',
title: t('模型部署'),
description: t('模型部署管理'),
},
{
key: 'subscription',
title: t('订阅管理'),
description: t('订阅套餐管理'),
},
{
key: 'redemption',
title: t('兑换码管理'),
description: t('兑换码生成管理'),
},
{ key: 'user', title: t('用户管理'), description: t('用户账户管理') },
{
key: 'setting',
title: t('系统设置'),
description: t('系统参数配置'),
},
],
},
]
.filter((section) => {
// 使用后端权限验证替代前端角色判断
return isSidebarSectionAllowed(section.key);
})
.map((section) => ({
...section,
modules: section.modules.filter((module) =>
isSidebarModuleAllowed(section.key, module.key),
),
}))
.filter(
(section) =>
// 过滤掉没有可用模块的区域
section.modules.length > 0 && isAllowedByAdmin(section.key),
);
return (
<Card className='!rounded-2xl shadow-sm border-0'>
{/* 卡片头部 */}
<div className='flex items-center mb-4'>
<Avatar size='small' color='purple' className='mr-3 shadow-md'>
<Settings size={16} />
</Avatar>
<div>
<Typography.Text className='text-lg font-medium'>
{t('左侧边栏个人设置')}
</Typography.Text>
<div className='text-xs text-gray-600'>
{t('个性化设置左侧边栏的显示内容')}
</div>
</div>
</div>
<div className='mb-4'>
<Text type='secondary' className='text-sm text-gray-600'>
{t('您可以个性化设置侧边栏的要显示功能')}
</Text>
</div>
{sectionConfigs.map((section) => (
<div key={section.key} className='mb-6'>
{/* 区域标题和总开关 */}
<div className='flex justify-between items-center mb-4 p-4 bg-gray-50 rounded-xl border border-gray-200'>
<div>
<div className='font-semibold text-base text-gray-900 mb-1'>
{section.title}
</div>
<Text className='text-xs text-gray-600'>
{section.description}
</Text>
</div>
<Switch
checked={sidebarModulesUser[section.key]?.enabled !== false}
onChange={handleSectionChange(section.key)}
size='default'
/>
</div>
{/* 功能模块网格 */}
<Row gutter={[12, 12]}>
{section.modules.map((module) => (
<Col key={module.key} xs={24} sm={12} md={8} lg={6} xl={6}>
<Card
className={`!rounded-xl border border-gray-200 hover:border-blue-300 transition-all duration-200 ${
sidebarModulesUser[section.key]?.enabled !== false
? ''
: 'opacity-50'
}`}
bodyStyle={{ padding: '16px' }}
hoverable
>
<div className='flex justify-between items-center h-full'>
<div className='flex-1 text-left'>
<div className='font-semibold text-sm text-gray-900 mb-1'>
{module.title}
</div>
<Text className='text-xs text-gray-600 leading-relaxed block'>
{module.description}
</Text>
</div>
<div className='ml-4'>
<Switch
checked={
sidebarModulesUser[section.key]?.[module.key] !==
false
}
onChange={handleModuleChange(section.key, module.key)}
size='default'
disabled={
sidebarModulesUser[section.key]?.enabled === false
}
/>
</div>
</div>
</Card>
</Col>
))}
</Row>
</div>
))}
{/* 底部按钮 */}
<div className='flex justify-end gap-3 mt-6 pt-4 border-t border-gray-200'>
<Button
type='tertiary'
onClick={resetSidebarModules}
className='!rounded-lg'
>
{t('重置为默认')}
</Button>
<Button
type='primary'
onClick={onSubmit}
loading={loading}
className='!rounded-lg'
>
{t('保存设置')}
</Button>
</div>
</Card>
);
}