Compare commits

..

16 Commits

Author SHA1 Message Date
t0ng7u
380e1b7d56 🔐 fix(oauth): stop authorize flow from bouncing to /console; respect next and redirect unauthenticated users to consent
Problem
- Starting OAuth from Discourse hit GET /api/oauth/authorize and 302’d to /login?next=/oauth/consent…
- The login page and AuthRedirect always navigated to /console when a session existed, ignoring next, which aborted the OAuth flow and dropped users in the console.

Changes
- Backend (src/oauth/server.go)
  - When not logged in, redirect directly to /oauth/consent?<original_query> instead of /login?next=…
  - Keep no-store headers; preserve the original authorize querystring.
- Frontend
  - web/src/helpers/auth.jsx: AuthRedirect now honors the login page’s next query param and only redirects to safe internal paths (starts with “/”, not “//”); otherwise falls back to /console.
  - web/src/components/auth/LoginForm.jsx: After successful login and after 2FA success, navigate to next when present and safe; otherwise go to /console.

Result
- The OAuth authorize flow now reliably reaches the consent screen.
- On approval, the server issues an authorization code and 302’s back to the client’s redirect_uri (e.g., Discourse), completing SSO as expected.

Security
- Sanitize next to avoid open-redirects by allowing only same-origin internal paths.

Compatibility
- No behavior change for normal username/password sign-ins outside the OAuth flow.
- No changes to token/userinfo endpoints.

Testing
- Manually verified end-to-end with Discourse OAuth2 Basic:
  - authorize → consent → approve → redirect with code
- Lint checks pass for modified files.
2025-09-25 13:02:40 +08:00
t0ng7u
63828349de 🔐 fix(oauth2): initialize JWKS on first key creation; prevent nil panic and set current key
Why
- First-time “Initialize Keys” caused a nil pointer panic when adding the first JWK to a nil JWKS set.
- As a result, the returned kid was missing and the first key appeared as “historical” until a second rotation.
- Improve first-time UX: only show Key Management when the server is healthy and guide admins to the correct init flow.

Backend (bug fix)
- src/oauth/server.go
  - RotateSigningKey / GenerateAndPersistKey / ImportPEMKey:
    - If simpleJWKSSet is nil, create a new jwk.NewSet() before AddKey, otherwise AddKey as usual.
    - Ensure currentKeyID is updated; enforceKeyRetention remains unchanged.
  - This prevents the nil pointer panic, ensures the first key is added to JWKS, and is immediately the current key.

Frontend (UX)
- web/src/components/settings/oauth2/OAuth2ServerSettings.jsx
  - Show “Key Management” only when OAuth2 is enabled AND server is healthy (serverInfo present).
  - Refine the warning banner text to instruct: enable OAuth2 & SSO → Save configuration → Key Management → Initialize Keys.
- web/src/components/settings/oauth2/modals/JWKSManagerModal.jsx
  - Dynamic primary action in “Key List” tab:
    - No keys → “Initialize Keys”
    - Has keys → “Rotate Keys”
  - Simplify error handling by relying on `message` + localized fallback.

Notes
- No API surface changes; functional bugfix plus UI/UX improvements.
- Linting passed; no new warnings.

Test plan
1) Start with OAuth2 enabled and no signing keys.
2) Open “Key Management” → click “Initialize Keys”.
3) Expect: success response with new kid; table shows the new kid as Current; JWKS endpoint returns the key; no server panic.
2025-09-23 05:08:51 +08:00
t0ng7u
5706f0ee9f 🌏 i18n: Improve i18n translation 2025-09-23 04:15:59 +08:00
t0ng7u
e9e1dbff5e ♻️ refactor: reorganize OAuth consent page structure
- Move OAuth consent component to dedicated OAuth directory as index.jsx
- Rename component export structure for better module organization
- Update App.jsx import path to reflect new OAuth page structure
- Maintain existing OAuth consent functionality while improving
2025-09-23 04:01:48 +08:00
t0ng7u
315eabc1e7 🎨 refactor(oauth2): merge modals and improve UI consistency
This commit consolidates OAuth2 client management components and
enhances the overall user experience with improved UI consistency.

### Major Changes:

**Component Consolidation:**
- Merge CreateOAuth2ClientModal.jsx and EditOAuth2ClientModal.jsx into OAuth2ClientModal.jsx
- Extract inline Modal.info into dedicated ClientInfoModal.jsx component
- Adopt consistent SideSheet + Card layout following EditTokenModal.jsx style

**UI/UX Improvements:**
- Replace custom client type selection with SemiUI RadioGroup component
- Use 'card' type RadioGroup with descriptive 'extra' prop for better UX
- Remove all Row/Col components in favor of flexbox and margin-based layouts
- Refactor redirect URI section to mimic JSONEditor.jsx visual style
- Add responsive design support for mobile devices

**Form Enhancements:**
- Add 'required' attributes to all mandatory form fields
- Implement placeholders for grant types, scopes, and redirect URI inputs
- Set grant types and scopes to default empty arrays
- Add dynamic validation and conditional rendering for client types
- Improve redirect URI management with template filling functionality

**Bug Fixes:**
- Fix SideSheet closing direction consistency between create/edit modes
- Resolve client_type submission issue (object vs string)
- Prevent "Client Credentials" selection for public clients
- Fix grant type filtering when switching between client types
- Resolve i18n issues for API scope options (api:read, api:write)

**Code Quality:**
- Extract RedirectUriCard as reusable sub-component
- Add comprehensive internationalization support
- Implement proper state management and form validation
- Follow single responsibility principle for component separation

**Files Modified:**
- web/src/components/settings/oauth2/modals/OAuth2ClientModal.jsx
- web/src/components/settings/oauth2/modals/ClientInfoModal.jsx (new)
- web/src/components/settings/oauth2/OAuth2ClientSettings.jsx
- web/src/i18n/locales/en.json

**Files Removed:**
- web/src/components/settings/oauth2/modals/CreateOAuth2ClientModal.jsx
- web/src/components/settings/oauth2/modals/EditOAuth2ClientModal.jsx

This refactoring significantly improves code maintainability, reduces
duplication, and provides a more consistent and intuitive user interface
for OAuth2 client management.
2025-09-23 03:49:53 +08:00
t0ng7u
359dbc9d94 feat(oauth2): enhance JWKS manager modal with improved UX and i18n support
- Refactor JWKSManagerModal with tab-based navigation using Card components
- Add comprehensive i18n support with English translations for all text
- Optimize header actions: refresh button only appears in key list tab
- Improve responsive design using ResponsiveModal component
- Move cautionary text from bottom to Card titles for better visibility
- Update button styles: danger type for delete, circle shape for status tags
- Standardize code formatting (single quotes, multiline formatting)
- Enhance user workflow: separate Import PEM and Generate PEM operations
- Remove redundant cancel buttons as modal already has close icon

Breaking changes: None
Affects: JWKS key management, OAuth2 settings UI
2025-09-23 01:16:17 +08:00
t0ng7u
e157ea6ba2 🎨 style(oauth2): modernize Empty component and clean up inline styles
- **Empty Component Enhancement:**
  - Replace custom User icon with professional IllustrationNoResult from Semi Design
  - Add dark mode support with IllustrationNoResultDark component
  - Standardize illustration size to 150x150px for consistency
  - Add proper padding (30px) to match design system standards

- **Style Modernization:**
  - Convert inline styles to Tailwind CSS classes where appropriate
  - Replace `style={{ marginBottom: 16 }}` with `className='mb-4'`
  - Remove redundant `style={{ marginTop: 8 }}` from Table component
  - Remove custom `style={{ marginTop: 16 }}` from pagination and button

- **Pagination Simplification:**
  - Simplify showTotal configuration from custom function to boolean `true`
  - Remove unnecessary `size='small'` property from pagination
  - Clean up pagination styling for better consistency

- **Design System Alignment:**
  - Ensure Empty component matches UsersTable styling patterns
  - Improve visual consistency across OAuth2 management interfaces
  - Follow Semi Design illustration guidelines for empty states

- **Code Quality:**
  - Reduce inline style usage in favor of utility classes
  - Simplify component props where default behavior is sufficient
  - Maintain functionality while improving maintainability

This update enhances visual consistency and follows modern React styling practices while maintaining all existing functionality.
2025-09-20 23:30:26 +08:00
t0ng7u
dc3dba0665 enhance(oauth2): improve UI components and code display experience
- **Table Layout Optimization:**
  - Remove description column from OAuth2 client table to save space
  - Add tooltip on client name hover to display full description
  - Adjust table scroll width from 1200px to 1000px for better layout
  - Improve client name column width to 180px for better readability

- **Action Button Simplification:**
  - Replace icon-only buttons with text labels for better accessibility
  - Simplify Popconfirm content by removing complex styled layouts
  - Remove unnecessary Tooltip wrappers around action buttons
  - Clean up unused Lucide icon imports (Edit, Key, Trash2)

- **Code Display Enhancement:**
  - Replace basic <pre> tags with CodeViewer component in modal dialogs
  - Add syntax highlighting for JSON content in ServerInfoModal and JWKSInfoModal
  - Implement copy-to-clipboard functionality for server info and JWKS data
  - Add performance optimization for large content display
  - Provide expandable/collapsible interface for better UX

- **Component Architecture:**
  - Import and integrate CodeViewer component in both modal components
  - Set appropriate props: content, title, and language='json'
  - Maintain loading states and error handling functionality

- **Internationalization:**
  - Add English translations for new UI elements:
    * '暂无描述': 'No description'
    * 'OAuth2 服务器配置': 'OAuth2 Server Configuration'
    * 'JWKS 密钥集': 'JWKS Key Set'

- **User Experience Improvements:**
  - Enhanced tooltip interaction for description display
  - Better visual feedback with cursor-help styling
  - Improved code readability with professional dark theme
  - Consistent styling across all OAuth2 management interfaces

This update focuses on UI/UX improvements while maintaining full functionality and adding modern code viewing capabilities to the OAuth2 management system.
2025-09-20 23:19:42 +08:00
t0ng7u
81272da9ac ♻️ refactor(oauth2): restructure OAuth2 client settings UI and extract modal components
- **UI Restructuring:**
  - Separate client info into individual table columns (name, ID, description)
  - Replace icon-only action buttons with text labels for better UX
  - Adjust table scroll width from 1000px to 1200px for new column layout
  - Remove unnecessary Tooltip wrappers and Lucide icons (Edit, Key, Trash2)

- **Component Architecture:**
  - Extract all modal dialogs into separate reusable components:
    * SecretDisplayModal.jsx - for displaying regenerated client secrets
    * ServerInfoModal.jsx - for OAuth2 server configuration info
    * JWKSInfoModal.jsx - for JWKS key set information
  - Simplify main component by removing ~60 lines of inline modal code
  - Implement proper state management for each modal component

- **Code Quality:**
  - Remove unused imports and clean up component dependencies
  - Consolidate modal logic into dedicated components with error handling
  - Improve code maintainability and reusability across the application

- **Internationalization:**
  - Add English translation for '客户端名称': 'Client Name'
  - Remove duplicate translation keys to fix linter warnings
  - Ensure all new components support full i18n functionality

- **User Experience:**
  - Enhance table readability with dedicated columns for each data type
  - Maintain copyable client ID functionality in separate column
  - Improve action button accessibility with clear text labels
  - Add loading states and proper error handling in modal components

This refactoring improves code organization, enhances user experience, and follows React best practices for component composition and separation of concerns.
2025-09-20 22:52:50 +08:00
t0ng7u
926cad87b3 📱 feat(oauth): implement responsive design for consent page
- Add responsive layout for user info section with flex-col on mobile
- Optimize button layout: vertical stack on mobile, horizontal on desktop
- Implement mobile-first approach with sm: breakpoints throughout
- Adjust container width: max-w-sm on mobile, max-w-lg on desktop
- Enhance touch targets with larger buttons (size='large') on mobile
- Improve content hierarchy with primary action button on top for mobile
- Add responsive padding and spacing: px-3 sm:px-4, py-6 sm:py-8
- Optimize text sizing: text-sm sm:text-base for better mobile readability
- Implement responsive gaps: gap-4 sm:gap-6 for icon spacing
- Add break-all class for long URL text wrapping
- Adjust meta info card spacing and dot separator sizing
- Ensure consistent responsive padding across all content sections

This update significantly improves the mobile user experience while
maintaining the desktop layout, following mobile-first design principles
with Tailwind CSS responsive utilities.
2025-09-20 17:45:58 +08:00
t0ng7u
418ce449b7 feat(oauth): redesign consent page with GitHub-style UI and improved UX
- Redesign OAuth consent page layout with centered card design
- Implement GitHub-style authorization flow presentation
- Add application popover with detailed information on hover
- Replace generic icons with scope-specific icons (email, profile, admin, etc.)
- Integrate i18n support for all hardcoded strings
- Optimize permission display with encapsulated ScopeItem component
- Improve visual hierarchy with Semi UI Divider components
- Unify avatar sizes and implement dynamic color generation
- Move action buttons and redirect info to card footer
- Add separate meta information card for technical details
- Remove redundant color styles to rely on Semi UI theming
- Enhance user account section with clearer GitHub-style messaging
- Replace dot separators with Lucide icons for better visual consistency
- Add site logo with fallback mechanism for branding
- Implement responsive design with Tailwind CSS utilities

This redesign significantly improves the OAuth consent experience by following
modern UI patterns and providing clearer information hierarchy for users.
2025-09-20 17:01:00 +08:00
Seefs
4a02ab23ce rm env 2025-09-16 17:21:11 +08:00
Seefs
984097c60b rm docs 2025-09-16 17:20:34 +08:00
Seefs
5550ec017e feat: oauth2 2025-09-16 17:10:01 +08:00
Seefs
9e6752e0ee Merge branch 'main-upstream' into feature/sso 2025-09-16 13:31:40 +08:00
Seefs
91a0eb7031 wip sso 2025-09-08 12:09:26 +08:00
128 changed files with 7971 additions and 3237 deletions

View File

@@ -38,21 +38,21 @@ jobs:
- name: Build Backend (amd64)
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
- name: Build Backend (arm64)
run: |
sudo apt-get update
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api-arm64
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
with:
files: |
new-api
new-api-arm64
one-api
one-api-arm64
draft: true
generate_release_notes: true
env:

View File

@@ -39,12 +39,12 @@ jobs:
- name: Build Backend
run: |
go mod download
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o new-api-macos
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
with:
files: new-api-macos
files: one-api-macos
draft: true
generate_release_notes: true
env:

View File

@@ -41,12 +41,12 @@ jobs:
- name: Build Backend
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o new-api.exe
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
with:
files: new-api.exe
files: one-api.exe
draft: true
generate_release_notes: true
env:

View File

@@ -1,22 +0,0 @@
package common
import "net"
func IsPrivateIP(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
private := []net.IPNet{
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
}
for _, privateNet := range private {
if privateNet.Contains(ip) {
return true
}
}
return false
}

View File

@@ -1,327 +0,0 @@
package common
import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
)
// SSRFProtection SSRF防护配置
type SSRFProtection struct {
AllowPrivateIp bool
DomainFilterMode bool // true: 白名单, false: 黑名单
DomainList []string // domain format, e.g. example.com, *.example.com
IpFilterMode bool // true: 白名单, false: 黑名单
IpList []string // CIDR or single IP
AllowedPorts []int // 允许的端口范围
ApplyIPFilterForDomain bool // 对域名启用IP过滤
}
// DefaultSSRFProtection 默认SSRF防护配置
var DefaultSSRFProtection = &SSRFProtection{
AllowPrivateIp: false,
DomainFilterMode: true,
DomainList: []string{},
IpFilterMode: true,
IpList: []string{},
AllowedPorts: []int{},
}
// isPrivateIP 检查IP是否为私有地址
func isPrivateIP(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
// 检查私有网段
private := []net.IPNet{
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
{IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
{IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
}
for _, privateNet := range private {
if privateNet.Contains(ip) {
return true
}
}
// 检查IPv6私有地址
if ip.To4() == nil {
// IPv6 loopback
if ip.Equal(net.IPv6loopback) {
return true
}
// IPv6 link-local
if strings.HasPrefix(ip.String(), "fe80:") {
return true
}
// IPv6 unique local
if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
return true
}
}
return false
}
// parsePortRanges 解析端口范围配置
// 支持格式: "80", "443", "8000-9000"
func parsePortRanges(portConfigs []string) ([]int, error) {
var ports []int
for _, config := range portConfigs {
config = strings.TrimSpace(config)
if config == "" {
continue
}
if strings.Contains(config, "-") {
// 处理端口范围 "8000-9000"
parts := strings.Split(config, "-")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid port range format: %s", config)
}
startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
if err != nil {
return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
}
endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
if err != nil {
return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
}
if startPort > endPort {
return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
}
if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
}
// 添加范围内的所有端口
for port := startPort; port <= endPort; port++ {
ports = append(ports, port)
}
} else {
// 处理单个端口 "80"
port, err := strconv.Atoi(config)
if err != nil {
return nil, fmt.Errorf("invalid port number: %s", config)
}
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
}
ports = append(ports, port)
}
}
return ports, nil
}
// isAllowedPort 检查端口是否被允许
func (p *SSRFProtection) isAllowedPort(port int) bool {
if len(p.AllowedPorts) == 0 {
return true // 如果没有配置端口限制,则允许所有端口
}
for _, allowedPort := range p.AllowedPorts {
if port == allowedPort {
return true
}
}
return false
}
// isDomainWhitelisted 检查域名是否在白名单中
func isDomainListed(domain string, list []string) bool {
if len(list) == 0 {
return false
}
domain = strings.ToLower(domain)
for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item))
if item == "" {
continue
}
// 精确匹配
if domain == item {
return true
}
// 通配符匹配 (*.example.com)
if strings.HasPrefix(item, "*.") {
suffix := strings.TrimPrefix(item, "*.")
if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
return true
}
}
}
return false
}
func (p *SSRFProtection) isDomainAllowed(domain string) bool {
listed := isDomainListed(domain, p.DomainList)
if p.DomainFilterMode { // 白名单
return listed
}
// 黑名单
return !listed
}
// isIPWhitelisted 检查IP是否在白名单中
func isIPListed(ip net.IP, list []string) bool {
if len(list) == 0 {
return false
}
for _, whitelistCIDR := range list {
_, network, err := net.ParseCIDR(whitelistCIDR)
if err != nil {
// 尝试作为单个IP处理
if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
if ip.Equal(whitelistIP) {
return true
}
}
continue
}
if network.Contains(ip) {
return true
}
}
return false
}
// IsIPAccessAllowed 检查IP是否允许访问
func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
// 私有IP限制
if isPrivateIP(ip) && !p.AllowPrivateIp {
return false
}
listed := isIPListed(ip, p.IpList)
if p.IpFilterMode { // 白名单
return listed
}
// 黑名单
return !listed
}
// ValidateURL 验证URL是否安全
func (p *SSRFProtection) ValidateURL(urlStr string) error {
// 解析URL
u, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL format: %v", err)
}
// 只允许HTTP/HTTPS协议
if u.Scheme != "http" && u.Scheme != "https" {
return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
}
// 解析主机和端口
host, portStr, err := net.SplitHostPort(u.Host)
if err != nil {
// 没有端口,使用默认端口
host = u.Hostname()
if u.Scheme == "https" {
portStr = "443"
} else {
portStr = "80"
}
}
// 验证端口
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
if !p.isAllowedPort(port) {
return fmt.Errorf("port %d is not allowed", port)
}
// 如果 host 是 IP则跳过域名检查
if ip := net.ParseIP(host); ip != nil {
if !p.IsIPAccessAllowed(ip) {
if isPrivateIP(ip) {
return fmt.Errorf("private IP address not allowed: %s", ip.String())
}
if p.IpFilterMode {
return fmt.Errorf("ip not in whitelist: %s", ip.String())
}
return fmt.Errorf("ip in blacklist: %s", ip.String())
}
return nil
}
// 先进行域名过滤
if !p.isDomainAllowed(host) {
if p.DomainFilterMode {
return fmt.Errorf("domain not in whitelist: %s", host)
}
return fmt.Errorf("domain in blacklist: %s", host)
}
// 若未启用对域名应用IP过滤则到此通过
if !p.ApplyIPFilterForDomain {
return nil
}
// 解析域名对应IP并检查
ips, err := net.LookupIP(host)
if err != nil {
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
}
for _, ip := range ips {
if !p.IsIPAccessAllowed(ip) {
if isPrivateIP(ip) && !p.AllowPrivateIp {
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
}
if p.IpFilterMode {
return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
}
return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
}
}
return nil
}
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
// 如果SSRF防护被禁用直接返回成功
if !enableSSRFProtection {
return nil
}
// 解析端口范围配置
allowedPortInts, err := parsePortRanges(allowedPorts)
if err != nil {
return fmt.Errorf("request reject - invalid port configuration: %v", err)
}
protection := &SSRFProtection{
AllowPrivateIp: allowPrivateIp,
DomainFilterMode: domainFilterMode,
DomainList: domainList,
IpFilterMode: ipFilterMode,
IpList: ipList,
AllowedPorts: allowedPortInts,
ApplyIPFilterForDomain: applyIPFilterForDomain,
}
return protection.ValidateURL(urlStr)
}

View File

@@ -2,10 +2,9 @@ package common
import (
"fmt"
"github.com/gin-gonic/gin"
"os"
"time"
"github.com/gin-gonic/gin"
)
func SysLog(s string) {
@@ -23,33 +22,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
func LogStartupSuccess(startTime time.Time, port string) {
duration := time.Since(startTime)
durationMs := duration.Milliseconds()
// Get network IPs
networkIps := GetNetworkIps()
// Print blank line for spacing
fmt.Fprintf(gin.DefaultWriter, "\n")
// Print the main success message
fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
fmt.Fprintf(gin.DefaultWriter, "\n")
// Skip fancy startup message in container environments
if !IsRunningInContainer() {
// Print local URL
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
}
// Print network URLs
for _, ip := range networkIps {
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
}
// Print blank line for spacing
fmt.Fprintf(gin.DefaultWriter, "\n")
}

View File

@@ -68,78 +68,6 @@ func GetIp() (ip string) {
return
}
func GetNetworkIps() []string {
var networkIps []string
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return networkIps
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip := ipNet.IP.String()
// Include common private network ranges
if strings.HasPrefix(ip, "10.") ||
strings.HasPrefix(ip, "172.") ||
strings.HasPrefix(ip, "192.168.") {
networkIps = append(networkIps, ip)
}
}
}
}
return networkIps
}
// IsRunningInContainer detects if the application is running inside a container
func IsRunningInContainer() bool {
// Method 1: Check for .dockerenv file (Docker containers)
if _, err := os.Stat("/.dockerenv"); err == nil {
return true
}
// Method 2: Check cgroup for container indicators
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
content := string(data)
if strings.Contains(content, "docker") ||
strings.Contains(content, "containerd") ||
strings.Contains(content, "kubepods") ||
strings.Contains(content, "/lxc/") {
return true
}
}
// Method 3: Check environment variables commonly set by container runtimes
containerEnvVars := []string{
"KUBERNETES_SERVICE_HOST",
"DOCKER_CONTAINER",
"container",
}
for _, envVar := range containerEnvVars {
if os.Getenv(envVar) != "" {
return true
}
}
// Method 4: Check if init process is not the traditional init
if data, err := os.ReadFile("/proc/1/comm"); err == nil {
comm := strings.TrimSpace(string(data))
// In containers, process 1 is often not "init" or "systemd"
if comm != "init" && comm != "systemd" {
// Additional check: if it's a common container entrypoint
if strings.Contains(comm, "docker") ||
strings.Contains(comm, "containerd") ||
strings.Contains(comm, "runc") {
return true
}
}
}
return false
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024

View File

@@ -11,10 +11,8 @@ const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
TaskActionGenerate = "generate"
TaskActionTextGenerate = "textGenerate"
TaskActionFirstTailGenerate = "firstTailGenerate"
TaskActionReferenceGenerate = "referenceGenerate"
TaskActionGenerate = "generate"
TaskActionTextGenerate = "textGenerate"
)
var SunoModel2Action = map[string]string{

View File

@@ -90,11 +90,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
requestPath = "/v1/embeddings" // 修改请求路径
}
// VolcEngine 图像生成模型
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
requestPath = "/v1/images/generations"
}
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: requestPath}, // 使用动态路径
@@ -114,21 +109,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
}
// 重新检查模型类型并更新请求路径
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") ||
strings.Contains(testModel, "bge-") ||
strings.Contains(testModel, "embed") ||
channel.Type == constant.ChannelTypeMokaAI {
requestPath = "/v1/embeddings"
c.Request.URL.Path = requestPath
}
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
requestPath = "/v1/images/generations"
c.Request.URL.Path = requestPath
}
cache, err := model.GetUserCache(1)
if err != nil {
return testResult{
@@ -160,9 +140,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
if c.Request.URL.Path == "/v1/embeddings" {
relayFormat = types.RelayFormatEmbedding
}
if c.Request.URL.Path == "/v1/images/generations" {
relayFormat = types.RelayFormatOpenAIImage
}
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
@@ -224,22 +201,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
}
// 调用专门用于 Embedding 的转换函数
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
// 创建一个 ImageRequest
prompt := "cat"
if request.Prompt != nil {
if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
prompt = promptStr
}
}
imageRequest := dto.ImageRequest{
Prompt: prompt,
Model: request.Model,
N: uint(request.N),
Size: request.Size,
}
// 调用专门用于图像生成的转换函数
convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
} else {
// 对其他所有请求类型(如 Chat保持原有逻辑
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)

View File

@@ -8,7 +8,6 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/service"
"strconv"
"strings"
@@ -189,8 +188,6 @@ func FetchUpstreamModels(c *gin.Context) {
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
case constant.ChannelTypeZhipu_v4:
url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
default:
url = fmt.Sprintf("%s/v1/models", baseURL)
}
@@ -504,10 +501,9 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
}
type AddChannelRequest struct {
Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
Channel *model.Channel `json:"channel"`
Mode string `json:"mode"`
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
Channel *model.Channel `json:"channel"`
}
func getVertexArrayKeys(keys string) ([]string, error) {
@@ -620,13 +616,6 @@ func AddChannel(c *gin.Context) {
}
localChannel := addChannelRequest.Channel
localChannel.Key = key
if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
keyPrefix := localChannel.Key
if len(localChannel.Key) > 8 {
keyPrefix = localChannel.Key[:8]
}
localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
}
channels = append(channels, *localChannel)
}
err = model.BatchInsertChannels(channels)
@@ -634,7 +623,6 @@ func AddChannel(c *gin.Context) {
common.ApiError(c, err)
return
}
service.ResetProxyClientCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -896,7 +884,6 @@ func UpdateChannel(c *gin.Context) {
return
}
model.InitChannelCache()
service.ResetProxyClientCache()
channel.Key = ""
clearChannelInfo(&channel.Channel)
c.JSON(http.StatusOK, gin.H{
@@ -1106,8 +1093,8 @@ func CopyChannel(c *gin.Context) {
// MultiKeyManageRequest represents the request for multi-key management operations
type MultiKeyManageRequest struct {
ChannelId int `json:"channel_id"`
Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
Page int `json:"page,omitempty"` // for get_key_status pagination
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
@@ -1435,86 +1422,6 @@ func ManageMultiKeys(c *gin.Context) {
})
return
case "delete_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要删除的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
keys := channel.GetKeys()
var remainingKeys []string
var newStatusList = make(map[int]int)
var newDisabledTime = make(map[int]int64)
var newDisabledReason = make(map[int]string)
newIndex := 0
for i, key := range keys {
// 跳过要删除的密钥
if i == keyIndex {
continue
}
remainingKeys = append(remainingKeys, key)
// 保留其他密钥的状态信息,重新索引
if channel.ChannelInfo.MultiKeyStatusList != nil {
if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
newStatusList[newIndex] = status
}
}
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
newDisabledTime[newIndex] = t
}
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
newDisabledReason[newIndex] = r
}
}
newIndex++
}
if len(remainingKeys) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除最后一个密钥",
})
return
}
// Update channel with remaining keys
channel.Key = strings.Join(remainingKeys, "\n")
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
channel.ChannelInfo.MultiKeyStatusList = newStatusList
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已删除",
})
return
case "delete_disabled_keys":
keys := channel.GetKeys()
var remainingKeys []string

375
controller/oauth.go Normal file
View File

@@ -0,0 +1,375 @@
package controller
import (
"encoding/json"
"net/http"
"one-api/model"
"one-api/setting/system_setting"
"one-api/src/oauth"
"time"
"github.com/gin-gonic/gin"
jwt "github.com/golang-jwt/jwt/v5"
"one-api/middleware"
"strconv"
"strings"
)
// GetJWKS 获取JWKS公钥集
func GetJWKS(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{
"error": "OAuth2 server is disabled",
})
return
}
// lazy init if needed
_ = oauth.EnsureInitialized()
jwks := oauth.GetJWKS()
if jwks == nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "JWKS not available",
})
return
}
// 设置CORS headers
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET")
c.Header("Access-Control-Allow-Headers", "Content-Type")
c.Header("Cache-Control", "public, max-age=3600") // 缓存1小时
// 返回JWKS
c.Header("Content-Type", "application/json")
// 将JWKS转换为JSON字符串
jsonData, err := json.Marshal(jwks)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to marshal JWKS",
})
return
}
c.String(http.StatusOK, string(jsonData))
}
// OAuthTokenEndpoint OAuth2 令牌端点
func OAuthTokenEndpoint(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{
"error": "unsupported_grant_type",
"error_description": "OAuth2 server is disabled",
})
return
}
// 只允许POST请求
if c.Request.Method != "POST" {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"error": "invalid_request",
"error_description": "Only POST method is allowed",
})
return
}
// 只允许application/x-www-form-urlencoded内容类型
contentType := c.GetHeader("Content-Type")
if contentType == "" || !strings.Contains(strings.ToLower(contentType), "application/x-www-form-urlencoded") {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Content-Type must be application/x-www-form-urlencoded",
})
return
}
// lazy init
if err := oauth.EnsureInitialized(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
return
}
oauth.HandleTokenRequest(c)
}
// OAuthAuthorizeEndpoint OAuth2 授权端点
func OAuthAuthorizeEndpoint(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{
"error": "server_error",
"error_description": "OAuth2 server is disabled",
})
return
}
if err := oauth.EnsureInitialized(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
return
}
oauth.HandleAuthorizeRequest(c)
}
// OAuthServerInfo 获取OAuth2服务器信息
func OAuthServerInfo(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{
"error": "OAuth2 server is disabled",
})
return
}
// 返回OAuth2服务器的基本信息类似OpenID Connect Discovery
issuer := settings.Issuer
if issuer == "" {
scheme := "https"
if c.Request.TLS == nil {
if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
scheme = hdr
} else {
scheme = "http"
}
}
issuer = scheme + "://" + c.Request.Host
}
base := issuer + "/api"
c.JSON(http.StatusOK, gin.H{
"issuer": issuer,
"authorization_endpoint": base + "/oauth/authorize",
"token_endpoint": base + "/oauth/token",
"jwks_uri": base + "/.well-known/jwks.json",
"grant_types_supported": settings.AllowedGrantTypes,
"response_types_supported": []string{"code", "token"},
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
"code_challenge_methods_supported": []string{"S256"},
"scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
"default_private_key_path": settings.DefaultPrivateKeyPath,
})
}
// OAuthOIDCConfiguration OIDC discovery document
func OAuthOIDCConfiguration(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
return
}
issuer := settings.Issuer
if issuer == "" {
scheme := "https"
if c.Request.TLS == nil {
if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
scheme = hdr
} else {
scheme = "http"
}
}
issuer = scheme + "://" + c.Request.Host
}
base := issuer + "/api"
c.JSON(http.StatusOK, gin.H{
"issuer": issuer,
"authorization_endpoint": base + "/oauth/authorize",
"token_endpoint": base + "/oauth/token",
"userinfo_endpoint": base + "/oauth/userinfo",
"jwks_uri": base + "/.well-known/jwks.json",
"response_types_supported": []string{"code", "token"},
"grant_types_supported": settings.AllowedGrantTypes,
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{"RS256"},
"scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
"code_challenge_methods_supported": []string{"S256"},
"default_private_key_path": settings.DefaultPrivateKeyPath,
})
}
// OAuthIntrospect 令牌内省端点RFC 7662
func OAuthIntrospect(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{
"error": "OAuth2 server is disabled",
})
return
}
// 只允许POST请求
if c.Request.Method != "POST" {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"error": "invalid_request",
"error_description": "Only POST method is allowed",
})
return
}
token := c.PostForm("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{
"active": false,
})
return
}
tokenString := token
// 验证并解析JWT
parsed, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, jwt.ErrTokenSignatureInvalid
}
pub := oauth.GetPublicKeyByKid(func() string {
if v, ok := token.Header["kid"].(string); ok {
return v
}
return ""
}())
if pub == nil {
return nil, jwt.ErrTokenUnverifiable
}
return pub, nil
})
if err != nil || !parsed.Valid {
c.JSON(http.StatusOK, gin.H{"active": false})
return
}
claims, ok := parsed.Claims.(jwt.MapClaims)
if !ok {
c.JSON(http.StatusOK, gin.H{"active": false})
return
}
// 检查撤销
if jti, ok := claims["jti"].(string); ok && jti != "" {
if revoked, _ := model.IsTokenRevoked(jti); revoked {
c.JSON(http.StatusOK, gin.H{"active": false})
return
}
}
// 有效
resp := gin.H{"active": true}
for k, v := range claims {
resp[k] = v
}
resp["token_type"] = "Bearer"
c.JSON(http.StatusOK, resp)
}
// OAuthRevoke 令牌撤销端点RFC 7009
func OAuthRevoke(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{
"error": "OAuth2 server is disabled",
})
return
}
// 只允许POST请求
if c.Request.Method != "POST" {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"error": "invalid_request",
"error_description": "Only POST method is allowed",
})
return
}
token := c.PostForm("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Missing token parameter",
})
return
}
token = c.PostForm("token")
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "invalid_request",
"error_description": "Missing token parameter",
})
return
}
// 尝试解析JWT若成功则记录jti到撤销表
parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
return nil, jwt.ErrTokenSignatureInvalid
}
pub := oauth.GetRSAPublicKey()
if pub == nil {
return nil, jwt.ErrTokenUnverifiable
}
return pub, nil
})
if err == nil && parsed != nil && parsed.Valid {
if claims, ok := parsed.Claims.(jwt.MapClaims); ok {
var jti string
var exp int64
if v, ok := claims["jti"].(string); ok {
jti = v
}
if v, ok := claims["exp"].(float64); ok {
exp = int64(v)
} else if v, ok := claims["exp"].(int64); ok {
exp = v
}
if jti != "" {
// 如果没有exp默认撤销至当前+TTL 10分钟
if exp == 0 {
exp = time.Now().Add(10 * time.Minute).Unix()
}
_ = model.RevokeToken(jti, exp)
}
}
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// OAuthUserInfo returns OIDC userinfo based on access token
func OAuthUserInfo(c *gin.Context) {
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
return
}
// 需要 OAuthJWTAuth 中间件注入 claims
claims, ok := middleware.GetOAuthClaims(c)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
return
}
// scope 校验:必须包含 openid
scope, _ := claims["scope"].(string)
if !strings.Contains(" "+scope+" ", " openid ") {
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient_scope"})
return
}
sub, _ := claims["sub"].(string)
resp := gin.H{"sub": sub}
// 若包含 profile/email scope补充返回
if strings.Contains(" "+scope+" ", " profile ") || strings.Contains(" "+scope+" ", " email ") {
if uid, err := strconv.Atoi(sub); err == nil {
if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil {
if strings.Contains(" "+scope+" ", " profile ") {
resp["name"] = user.DisplayName
resp["preferred_username"] = user.Username
}
if strings.Contains(" "+scope+" ", " email ") {
resp["email"] = user.Email
resp["email_verified"] = true
}
}
}
}
c.JSON(http.StatusOK, resp)
}

374
controller/oauth_client.go Normal file
View File

@@ -0,0 +1,374 @@
package controller
import (
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/thanhpk/randstr"
)
// CreateOAuthClientRequest 创建OAuth客户端请求
type CreateOAuthClientRequest struct {
Name string `json:"name" binding:"required"`
ClientType string `json:"client_type" binding:"required,oneof=confidential public"`
GrantTypes []string `json:"grant_types" binding:"required"`
RedirectURIs []string `json:"redirect_uris"`
Scopes []string `json:"scopes" binding:"required"`
Description string `json:"description"`
RequirePKCE bool `json:"require_pkce"`
}
// UpdateOAuthClientRequest 更新OAuth客户端请求
type UpdateOAuthClientRequest struct {
ID string `json:"id" binding:"required"`
Name string `json:"name" binding:"required"`
ClientType string `json:"client_type" binding:"required,oneof=confidential public"`
GrantTypes []string `json:"grant_types" binding:"required"`
RedirectURIs []string `json:"redirect_uris"`
Scopes []string `json:"scopes" binding:"required"`
Description string `json:"description"`
RequirePKCE bool `json:"require_pkce"`
Status int `json:"status" binding:"required,oneof=1 2"`
}
// GetAllOAuthClients 获取所有OAuth客户端
func GetAllOAuthClients(c *gin.Context) {
page, _ := strconv.Atoi(c.Query("page"))
if page < 1 {
page = 1
}
perPage, _ := strconv.Atoi(c.Query("per_page"))
if perPage < 1 || perPage > 100 {
perPage = 20
}
startIdx := (page - 1) * perPage
clients, err := model.GetAllOAuthClients(startIdx, perPage)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 清理敏感信息
for _, client := range clients {
client.Secret = maskSecret(client.Secret)
}
total, _ := model.CountOAuthClients()
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": clients,
"total": total,
"page": page,
"per_page": perPage,
})
}
// SearchOAuthClients 搜索OAuth客户端
func SearchOAuthClients(c *gin.Context) {
keyword := c.Query("keyword")
if keyword == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "关键词不能为空",
})
return
}
clients, err := model.SearchOAuthClients(keyword)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// 清理敏感信息
for _, client := range clients {
client.Secret = maskSecret(client.Secret)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": clients,
})
}
// GetOAuthClient 获取单个OAuth客户端
func GetOAuthClient(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "ID不能为空",
})
return
}
client, err := model.GetOAuthClientByID(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "客户端不存在",
})
return
}
// 清理敏感信息
client.Secret = maskSecret(client.Secret)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": client,
})
}
// CreateOAuthClient 创建OAuth客户端
func CreateOAuthClient(c *gin.Context) {
var req CreateOAuthClientRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "请求参数错误: " + err.Error(),
})
return
}
// 验证授权类型
validGrantTypes := []string{"client_credentials", "authorization_code", "refresh_token"}
for _, grantType := range req.GrantTypes {
if !contains(validGrantTypes, grantType) {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "无效的授权类型: " + grantType,
})
return
}
}
// 如果包含authorization_code则必须提供redirect_uris
if contains(req.GrantTypes, "authorization_code") && len(req.RedirectURIs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "授权码模式需要提供重定向URI",
})
return
}
// 生成客户端ID和密钥
clientID := generateClientID()
clientSecret := ""
if req.ClientType == "confidential" {
clientSecret = generateClientSecret()
}
// 获取创建者ID
createdBy := c.GetInt("id")
// 创建客户端
client := &model.OAuthClient{
ID: clientID,
Secret: clientSecret,
Name: req.Name,
ClientType: req.ClientType,
RequirePKCE: req.RequirePKCE,
Status: common.UserStatusEnabled,
CreatedBy: createdBy,
Description: req.Description,
}
client.SetGrantTypes(req.GrantTypes)
client.SetRedirectURIs(req.RedirectURIs)
client.SetScopes(req.Scopes)
err := model.CreateOAuthClient(client)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "创建客户端失败: " + err.Error(),
})
return
}
// 返回结果(包含完整的客户端密钥,仅此一次)
c.JSON(http.StatusCreated, gin.H{
"success": true,
"message": "客户端创建成功",
"client_id": client.ID,
"client_secret": client.Secret, // 仅在创建时返回完整密钥
"data": client,
})
}
// UpdateOAuthClient 更新OAuth客户端
func UpdateOAuthClient(c *gin.Context) {
var req UpdateOAuthClientRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "请求参数错误: " + err.Error(),
})
return
}
// 获取现有客户端
client, err := model.GetOAuthClientByID(req.ID)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "客户端不存在",
})
return
}
// 验证授权类型
validGrantTypes := []string{"client_credentials", "authorization_code", "refresh_token"}
for _, grantType := range req.GrantTypes {
if !contains(validGrantTypes, grantType) {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "无效的授权类型: " + grantType,
})
return
}
}
// 更新客户端信息
client.Name = req.Name
client.ClientType = req.ClientType
client.RequirePKCE = req.RequirePKCE
client.Status = req.Status
client.Description = req.Description
client.SetGrantTypes(req.GrantTypes)
client.SetRedirectURIs(req.RedirectURIs)
client.SetScopes(req.Scopes)
err = model.UpdateOAuthClient(client)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "更新客户端失败: " + err.Error(),
})
return
}
// 清理敏感信息
client.Secret = maskSecret(client.Secret)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "客户端更新成功",
"data": client,
})
}
// DeleteOAuthClient 删除OAuth客户端
func DeleteOAuthClient(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "ID不能为空",
})
return
}
err := model.DeleteOAuthClient(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "删除客户端失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "客户端删除成功",
})
}
// RegenerateOAuthClientSecret 重新生成客户端密钥
func RegenerateOAuthClientSecret(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "ID不能为空",
})
return
}
client, err := model.GetOAuthClientByID(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "客户端不存在",
})
return
}
// 只有机密客户端才能重新生成密钥
if client.ClientType != "confidential" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "只有机密客户端才能重新生成密钥",
})
return
}
// 生成新密钥
client.Secret = generateClientSecret()
err = model.UpdateOAuthClient(client)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "重新生成密钥失败: " + err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "客户端密钥重新生成成功",
"client_secret": client.Secret, // 返回新生成的密钥
})
}
// generateClientID 生成客户端ID
func generateClientID() string {
return "client_" + randstr.String(16)
}
// generateClientSecret 生成客户端密钥
func generateClientSecret() string {
return randstr.String(32)
}
// maskSecret 掩码密钥显示
func maskSecret(secret string) string {
if len(secret) <= 6 {
return strings.Repeat("*", len(secret))
}
return secret[:3] + strings.Repeat("*", len(secret)-6) + secret[len(secret)-3:]
}
// contains 检查字符串切片是否包含指定值
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}

89
controller/oauth_keys.go Normal file
View File

@@ -0,0 +1,89 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/logger"
"one-api/src/oauth"
)
type rotateKeyRequest struct {
Kid string `json:"kid"`
}
type genKeyFileRequest struct {
Path string `json:"path"`
Kid string `json:"kid"`
Overwrite bool `json:"overwrite"`
}
type importPemRequest struct {
Pem string `json:"pem"`
Kid string `json:"kid"`
}
// RotateOAuthSigningKey rotates the OAuth2 JWT signing key (Root only)
func RotateOAuthSigningKey(c *gin.Context) {
var req rotateKeyRequest
_ = c.BindJSON(&req)
kid, err := oauth.RotateSigningKey(req.Kid)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
return
}
logger.LogInfo(c, "oauth signing key rotated: "+kid)
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
}
// ListOAuthSigningKeys returns current and historical JWKS signing keys
func ListOAuthSigningKeys(c *gin.Context) {
keys := oauth.ListSigningKeys()
c.JSON(http.StatusOK, gin.H{"success": true, "data": keys})
}
// DeleteOAuthSigningKey deletes a non-current key by kid
func DeleteOAuthSigningKey(c *gin.Context) {
kid := c.Param("kid")
if kid == "" {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "kid required"})
return
}
if err := oauth.DeleteSigningKey(kid); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
logger.LogInfo(c, "oauth signing key deleted: "+kid)
c.JSON(http.StatusOK, gin.H{"success": true})
}
// GenerateOAuthSigningKeyFile generates a private key file and rotates current kid
func GenerateOAuthSigningKeyFile(c *gin.Context) {
var req genKeyFileRequest
if err := c.ShouldBindJSON(&req); err != nil || req.Path == "" {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "path required"})
return
}
kid, err := oauth.GenerateAndPersistKey(req.Path, req.Kid, req.Overwrite)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
return
}
logger.LogInfo(c, "oauth signing key generated to file: "+req.Path+" kid="+kid)
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid, "path": req.Path})
}
// ImportOAuthSigningKey imports PEM text and rotates current kid
func ImportOAuthSigningKey(c *gin.Context) {
var req importPemRequest
if err := c.ShouldBindJSON(&req); err != nil || req.Pem == "" {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "pem required"})
return
}
kid, err := oauth.ImportPEMKey(req.Pem, req.Kid)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
logger.LogInfo(c, "oauth signing key imported from PEM, kid="+kid)
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
}

View File

@@ -128,33 +128,6 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "ImageRatio":
err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "图片倍率设置失败: " + err.Error(),
})
return
}
case "AudioRatio":
err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "音频倍率设置失败: " + err.Error(),
})
return
}
case "AudioCompletionRatio":
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "音频补全倍率设置失败: " + err.Error(),
})
return
}
case "ModelRequestRateLimitGroup":
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
if err != nil {

View File

@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
func PostSetup(c *gin.Context) {
// Check if setup is already completed
if constant.Setup {
c.JSON(200, gin.H{
c.JSON(400, gin.H{
"success": false,
"message": "系统已经初始化完成",
})
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
var req SetupRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{
c.JSON(400, gin.H{
"success": false,
"message": "请求参数有误",
})
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
if !rootExists {
// Validate username length: max 12 characters to align with model.User validation
if len(req.Username) > 12 {
c.JSON(200, gin.H{
c.JSON(400, gin.H{
"success": false,
"message": "用户名长度不能超过12个字符",
})
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
}
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(200, gin.H{
c.JSON(400, gin.H{
"success": false,
"message": "两次输入的密码不一致",
})
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
}
if len(req.Password) < 8 {
c.JSON(200, gin.H{
c.JSON(400, gin.H{
"success": false,
"message": "密码长度至少为8个字符",
})
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
// Create root user
hashedPassword, err := common.Password2Hash(req.Password)
if err != nil {
c.JSON(200, gin.H{
c.JSON(500, gin.H{
"success": false,
"message": "系统错误: " + err.Error(),
})
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
}
err = model.DB.Create(&rootUser).Error
if err != nil {
c.JSON(200, gin.H{
c.JSON(500, gin.H{
"success": false,
"message": "创建管理员账号失败: " + err.Error(),
})
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
// Save operation modes to database for persistence
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
if err != nil {
c.JSON(200, gin.H{
c.JSON(500, gin.H{
"success": false,
"message": "保存自用模式设置失败: " + err.Error(),
})
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
if err != nil {
c.JSON(200, gin.H{
c.JSON(500, gin.H{
"success": false,
"message": "保存演示站点模式设置失败: " + err.Error(),
})
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
}
err = model.DB.Create(&setup).Error
if err != nil {
c.JSON(200, gin.H{
c.JSON(500, gin.H{
"success": false,
"message": "系统初始化失败: " + err.Error(),
})

View File

@@ -225,8 +225,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {

View File

@@ -19,12 +19,4 @@ const (
type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
}
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
if s == nil || s.OpenRouterEnterprise == nil {
return false
}
return *s.OpenRouterEnterprise
}

View File

@@ -14,30 +14,7 @@ type GeminiChatRequest struct {
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
CachedContent string `json:"cachedContent,omitempty"`
}
type ToolConfig struct {
FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
}
type FunctionCallingConfig struct {
Mode FunctionCallingConfigMode `json:"mode,omitempty"`
AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
}
type FunctionCallingConfigMode string
type RetrievalConfig struct {
LatLng *LatLng `json:"latLng,omitempty"`
LanguageCode string `json:"languageCode,omitempty"`
}
type LatLng struct {
Latitude *float64 `json:"latitude,omitempty"`
Longitude *float64 `json:"longitude,omitempty"`
}
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
@@ -251,7 +228,6 @@ type GeminiChatTool struct {
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
URLContext any `json:"urlContext,omitempty"`
}
type GeminiChatGenerationConfig struct {
@@ -263,20 +239,12 @@ type GeminiChatGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
PresencePenalty *float32 `json:"presencePenalty,omitempty"`
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
Logprobs *int32 `json:"logprobs,omitempty"`
MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
type MediaResolution string
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason *string `json:"finishReason"`

View File

@@ -772,12 +772,11 @@ type OpenAIResponsesRequest struct {
Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"`
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Store json.RawMessage `json:"store,omitempty"`
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
Store bool `json:"store,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Text json.RawMessage `json:"text,omitempty"`

View File

@@ -6,10 +6,6 @@ import (
"one-api/types"
)
const (
ResponsesOutputTypeImageGenerationCall = "image_generation_call"
)
type SimpleResponse struct {
Usage `json:"usage"`
Error any `json:"error"`
@@ -277,42 +273,6 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
if len(o.Output) == 0 {
return false
}
for _, output := range o.Output {
if output.Type == ResponsesOutputTypeImageGenerationCall {
return true
}
}
return false
}
func (o *OpenAIResponsesResponse) GetQuality() string {
if len(o.Output) == 0 {
return ""
}
for _, output := range o.Output {
if output.Type == ResponsesOutputTypeImageGenerationCall {
return output.Quality
}
}
return ""
}
func (o *OpenAIResponsesResponse) GetSize() string {
if len(o.Output) == 0 {
return ""
}
for _, output := range o.Output {
if output.Type == ResponsesOutputTypeImageGenerationCall {
return output.Size
}
}
return ""
}
type IncompleteDetails struct {
Reasoning string `json:"reasoning"`
}
@@ -323,8 +283,6 @@ type ResponsesOutput struct {
Status string `json:"status"`
Role string `json:"role"`
Content []ResponsesOutputContent `json:"content"`
Quality string `json:"quality"`
Size string `json:"size"`
}
type ResponsesOutputContent struct {

View File

@@ -0,0 +1,326 @@
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>OAuth2/OIDC 授权码 + PKCE 前端演示</title>
<style>
:root { --bg:#0b0c10; --panel:#111317; --muted:#aab2bf; --accent:#3b82f6; --ok:#16a34a; --warn:#f59e0b; --err:#ef4444; --border:#1f2430; }
body { margin:0; font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial; background: var(--bg); color:#e5e7eb; }
.wrap { max-width: 980px; margin: 32px auto; padding: 0 16px; }
h1 { font-size: 22px; margin:0 0 16px; }
.card { background: var(--panel); border:1px solid var(--border); border-radius: 10px; padding: 16px; margin: 12px 0; }
.row { display:flex; gap:12px; flex-wrap:wrap; }
.col { flex: 1 1 280px; display:flex; flex-direction:column; }
label { font-size: 12px; color: var(--muted); margin-bottom: 6px; }
input, textarea, select { background:#0f1115; color:#e5e7eb; border:1px solid var(--border); padding:10px 12px; border-radius:8px; outline:none; }
textarea { min-height: 100px; resize: vertical; }
.btns { display:flex; gap:8px; flex-wrap:wrap; margin-top: 8px; }
button { background:#1a1f2b; color:#e5e7eb; border:1px solid var(--border); padding:8px 12px; border-radius:8px; cursor:pointer; }
button.primary { background: var(--accent); border-color: var(--accent); color:white; }
button.ok { background: var(--ok); border-color: var(--ok); color:white; }
button.warn { background: var(--warn); border-color: var(--warn); color:black; }
button.ghost { background: transparent; }
.muted { color: var(--muted); font-size: 12px; }
.mono { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
.grid2 { display:grid; grid-template-columns: 1fr 1fr; gap: 12px; }
@media (max-width: 880px){ .grid2 { grid-template-columns: 1fr; } }
.pill { padding: 3px 8px; border-radius:999px; font-size: 12px; border:1px solid var(--border); background:#0f1115; }
.ok { color: #10b981; }
.err { color: #ef4444; }
.sep { height:1px; background: var(--border); margin: 12px 0; }
</style>
</head>
<body>
<div class="wrap">
<h1>OAuth2/OIDC 授权码 + PKCE 前端演示</h1>
<div class="card">
<div class="row">
<div class="col">
<label>Issuer可选用于自动发现 /.well-known/openid-configuration</label>
<input id="issuer" placeholder="https://your-domain" />
<div class="btns"><button class="" id="btnDiscover">自动发现端点</button></div>
<div class="muted">提示:若未配置 Issuer可直接填写下方端点。</div>
</div>
</div>
<div class="row">
<div class="col"><label>Authorization Endpoint</label><input id="authorization_endpoint" placeholder="https://domain/api/oauth/authorize" /></div>
<div class="col"><label>Token Endpoint</label><input id="token_endpoint" placeholder="https://domain/api/oauth/token" /></div>
</div>
<div class="row">
<div class="col"><label>UserInfo Endpoint可选</label><input id="userinfo_endpoint" placeholder="https://domain/api/oauth/userinfo" /></div>
<div class="col"><label>Client ID</label><input id="client_id" placeholder="your-public-client-id" /></div>
</div>
<div class="row">
<div class="col"><label>Redirect URI当前页地址或你的回调</label><input id="redirect_uri" /></div>
<div class="col"><label>Scope</label><input id="scope" value="openid profile email" /></div>
</div>
<div class="row">
<div class="col"><label>State</label><input id="state" /></div>
<div class="col"><label>Nonce</label><input id="nonce" /></div>
</div>
<div class="row">
<div class="col"><label>Code Verifier自动生成不会上送</label><input id="code_verifier" class="mono" readonly /></div>
<div class="col"><label>Code ChallengeS256</label><input id="code_challenge" class="mono" readonly /></div>
</div>
<div class="btns">
<button id="btnGenPkce">生成 PKCE</button>
<button id="btnRandomState">随机 State</button>
<button id="btnRandomNonce">随机 Nonce</button>
<button id="btnMakeAuthURL">生成授权链接</button>
<button id="btnAuthorize" class="primary">跳转授权</button>
</div>
<div class="row" style="margin-top:8px;">
<div class="col">
<label>授权链接(只生成不跳转)</label>
<textarea id="authorize_url" class="mono" placeholder="(空)"></textarea>
<div class="btns"><button id="btnCopyAuthURL">复制链接</button></div>
</div>
</div>
<div class="sep"></div>
<div class="muted">说明:
<ul>
<li>本页为纯前端演示,适用于公开客户端(不需要 client_secret</li>
<li>如跨域调用 Token/UserInfo需要服务端正确设置 CORS建议将此 demo 部署到同源域名下。</li>
</ul>
</div>
<div class="sep"></div>
<div class="row">
<div class="col">
<label>粘贴 OIDC Discovery JSON/.well-known/openid-configuration</label>
<textarea id="conf_json" class="mono" placeholder='{"issuer":"https://...","authorization_endpoint":"...","token_endpoint":"...","userinfo_endpoint":"..."}'></textarea>
<div class="btns">
<button id="btnParseConf">解析并填充端点</button>
<button id="btnGenConf">用当前端点生成 JSON</button>
</div>
<div class="muted">可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。</div>
</div>
</div>
</div>
<div class="card">
<div class="row">
<div class="col">
<label>授权结果</label>
<div id="authResult" class="muted">等待授权...</div>
</div>
</div>
<div class="grid2" style="margin-top:12px;">
<div>
<label>Access Token</label>
<textarea id="access_token" class="mono" placeholder="(空)"></textarea>
<div class="btns">
<button id="btnCopyAT">复制</button>
<button id="btnCallUserInfo" class="ok">调用 UserInfo</button>
</div>
<div id="userinfoOut" class="muted" style="margin-top:6px;"></div>
</div>
<div>
<label>ID TokenJWT</label>
<textarea id="id_token" class="mono" placeholder="(空)"></textarea>
<div class="btns">
<button id="btnDecodeJWT">解码显示 Claims</button>
</div>
<pre id="jwtClaims" class="mono" style="white-space:pre-wrap; word-break:break-all; margin-top:6px;"></pre>
</div>
</div>
<div class="grid2" style="margin-top:12px;">
<div>
<label>Refresh Token</label>
<textarea id="refresh_token" class="mono" placeholder="(空)"></textarea>
<div class="btns">
<button id="btnRefreshToken">使用 Refresh Token 刷新</button>
</div>
</div>
<div>
<label>原始 Token 响应</label>
<textarea id="token_raw" class="mono" placeholder="(空)"></textarea>
</div>
</div>
</div>
</div>
<script>
const $ = (id) => document.getElementById(id);
const toB64Url = (buf) => btoa(String.fromCharCode(...new Uint8Array(buf))).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
async function sha256B64Url(str){
const data = new TextEncoder().encode(str);
const digest = await crypto.subtle.digest('SHA-256', data);
return toB64Url(digest);
}
function randStr(len=64){
const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
const arr = new Uint8Array(len); crypto.getRandomValues(arr);
return Array.from(arr, v => chars[v % chars.length]).join('');
}
function setAuthInfo(msg, ok=true){
const el = $('authResult');
el.textContent = msg;
el.className = ok ? 'ok' : 'err';
}
function qs(name){ const u=new URL(location.href); return u.searchParams.get(name); }
function persist(name, val){ sessionStorage.setItem('demo_'+name, val); }
function load(name){ return sessionStorage.getItem('demo_'+name) || ''; }
// init defaults
(function init(){
$('redirect_uri').value = window.location.origin + window.location.pathname;
// try load from discovery if issuer saved previously
const iss = load('issuer'); if(iss) $('issuer').value = iss;
const cid = load('client_id'); if(cid) $('client_id').value = cid;
const scp = load('scope'); if(scp) $('scope').value = scp;
})();
$('btnDiscover').onclick = async () => {
const iss = $('issuer').value.trim(); if(!iss){ alert('请填写 Issuer'); return; }
try{
persist('issuer', iss);
const res = await fetch(iss.replace(/\/$/,'') + '/api/.well-known/openid-configuration');
const d = await res.json();
$('authorization_endpoint').value = d.authorization_endpoint || '';
$('token_endpoint').value = d.token_endpoint || '';
$('userinfo_endpoint').value = d.userinfo_endpoint || '';
if (d.issuer) { $('issuer').value = d.issuer; persist('issuer', d.issuer); }
$('conf_json').value = JSON.stringify(d, null, 2);
setAuthInfo('已从发现文档加载端点', true);
}catch(e){ setAuthInfo('自动发现失败:'+e, false); }
};
$('btnGenPkce').onclick = async () => {
const v = randStr(64); const c = await sha256B64Url(v);
$('code_verifier').value = v; $('code_challenge').value = c;
persist('code_verifier', v); persist('code_challenge', c);
setAuthInfo('已生成 PKCE 参数', true);
};
$('btnRandomState').onclick = () => { $('state').value = randStr(16); persist('state', $('state').value); };
$('btnRandomNonce').onclick = () => { $('nonce').value = randStr(16); persist('nonce', $('nonce').value); };
function buildAuthorizeURLFromFields() {
const auth = $('authorization_endpoint').value.trim();
const token = $('token_endpoint').value.trim(); // just validate
const cid = $('client_id').value.trim();
const red = $('redirect_uri').value.trim();
const scp = $('scope').value.trim() || 'openid profile email';
const st = $('state').value.trim() || randStr(16);
const no = $('nonce').value.trim() || randStr(16);
const cc = $('code_challenge').value.trim();
const cv = $('code_verifier').value.trim();
if(!auth || !token || !cid || !red){ throw new Error('请先完善端点/ClientID/RedirectURI'); }
if(!cc || !cv){ throw new Error('请先生成 PKCE'); }
persist('authorization_endpoint', auth); persist('token_endpoint', token);
persist('client_id', cid); persist('redirect_uri', red); persist('scope', scp);
persist('state', st); persist('nonce', no); persist('code_verifier', cv);
const u = new URL(auth);
u.searchParams.set('response_type', 'code');
u.searchParams.set('client_id', cid);
u.searchParams.set('redirect_uri', red);
u.searchParams.set('scope', scp);
u.searchParams.set('state', st);
u.searchParams.set('nonce', no);
u.searchParams.set('code_challenge', cc);
u.searchParams.set('code_challenge_method', 'S256');
return u.toString();
}
$('btnMakeAuthURL').onclick = () => {
try {
const url = buildAuthorizeURLFromFields();
$('authorize_url').value = url;
setAuthInfo('已生成授权链接', true);
} catch(e){ setAuthInfo(e.message, false); }
};
$('btnAuthorize').onclick = () => {
try { const url = buildAuthorizeURLFromFields(); location.href = url; }
catch(e){ setAuthInfo(e.message, false); }
};
$('btnCopyAuthURL').onclick = async () => { try{ await navigator.clipboard.writeText($('authorize_url').value); }catch{} };
// Parse OIDC discovery JSON pasted by user
$('btnParseConf').onclick = () => {
const txt = $('conf_json').value.trim(); if(!txt){ alert('请先粘贴 JSON'); return; }
try{
const d = JSON.parse(txt);
if (d.issuer) { $('issuer').value = d.issuer; persist('issuer', d.issuer); }
if (d.authorization_endpoint) $('authorization_endpoint').value = d.authorization_endpoint;
if (d.token_endpoint) $('token_endpoint').value = d.token_endpoint;
if (d.userinfo_endpoint) $('userinfo_endpoint').value = d.userinfo_endpoint;
setAuthInfo('已解析配置并填充端点', true);
}catch(e){ setAuthInfo('解析失败:'+e, false); }
};
// Generate a minimal discovery JSON from current fields
$('btnGenConf').onclick = () => {
const d = {
issuer: $('issuer').value.trim() || undefined,
authorization_endpoint: $('authorization_endpoint').value.trim() || undefined,
token_endpoint: $('token_endpoint').value.trim() || undefined,
userinfo_endpoint: $('userinfo_endpoint').value.trim() || undefined,
};
$('conf_json').value = JSON.stringify(d, null, 2);
};
async function postForm(url, data){
const body = Object.entries(data).map(([k,v])=> `${encodeURIComponent(k)}=${encodeURIComponent(v)}`).join('&');
const res = await fetch(url, { method:'POST', headers:{ 'Content-Type':'application/x-www-form-urlencoded' }, body });
if(!res.ok){ const t = await res.text(); throw new Error(`HTTP ${res.status} ${t}`); }
return res.json();
}
async function handleCallback(){
const code = qs('code'); const err = qs('error');
const state = qs('state');
if(err){ setAuthInfo('授权失败:'+err, false); return; }
if(!code){ setAuthInfo('等待授权...', true); return; }
// state check
if(state && load('state') && state !== load('state')){ setAuthInfo('state 不匹配,已拒绝', false); return; }
try{
const tokenEp = load('token_endpoint');
const data = await postForm(tokenEp, {
grant_type:'authorization_code',
code,
client_id: load('client_id'),
redirect_uri: load('redirect_uri'),
code_verifier: load('code_verifier')
});
$('access_token').value = data.access_token || '';
$('id_token').value = data.id_token || '';
$('refresh_token').value = data.refresh_token || '';
$('token_raw').value = JSON.stringify(data, null, 2);
setAuthInfo('授权成功,已获取令牌', true);
}catch(e){ setAuthInfo('交换令牌失败:'+e.message, false); }
}
handleCallback();
$('btnCopyAT').onclick = async () => { try{ await navigator.clipboard.writeText($('access_token').value); }catch{} };
$('btnDecodeJWT').onclick = () => {
const t = $('id_token').value.trim(); if(!t){ $('jwtClaims').textContent='(空)'; return; }
const parts = t.split('.'); if(parts.length<2){ $('jwtClaims').textContent='格式错误'; return; }
try{ const json = JSON.parse(atob(parts[1].replace(/-/g,'+').replace(/_/g,'/'))); $('jwtClaims').textContent = JSON.stringify(json, null, 2);}catch(e){ $('jwtClaims').textContent='解码失败:'+e; }
};
$('btnCallUserInfo').onclick = async () => {
const at = $('access_token').value.trim(); const ep = $('userinfo_endpoint').value.trim(); if(!at||!ep){ alert('请填写UserInfo端点并获取AccessToken'); return; }
try{
const res = await fetch(ep, { headers:{ Authorization: 'Bearer '+at } });
const data = await res.json(); $('userinfoOut').textContent = JSON.stringify(data, null, 2);
}catch(e){ $('userinfoOut').textContent = '调用失败:'+e; }
};
$('btnRefreshToken').onclick = async () => {
const rt = $('refresh_token').value.trim(); if(!rt){ alert('没有刷新令牌'); return; }
try{
const tokenEp = load('token_endpoint');
const data = await postForm(tokenEp, {
grant_type:'refresh_token',
refresh_token: rt,
client_id: load('client_id')
});
$('access_token').value = data.access_token || '';
$('id_token').value = data.id_token || '';
$('refresh_token').value = data.refresh_token || '';
$('token_raw').value = JSON.stringify(data, null, 2);
setAuthInfo('刷新成功', true);
}catch(e){ setAuthInfo('刷新失败:'+e.message, false); }
};
</script>
</body>
</html>

View File

@@ -0,0 +1,181 @@
package main
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"path"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
func main() {
// 测试 Client Credentials 流程
//testClientCredentials()
// 测试 Authorization Code + PKCE 流程(需要浏览器交互)
testAuthorizationCode()
}
// testClientCredentials 测试服务对服务认证
func testClientCredentials() {
fmt.Println("=== Testing Client Credentials Flow ===")
cfg := clientcredentials.Config{
ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
ClientSecret: "hLLdn2Ia4UM7hcsJaSuUFDV0Px9BrkNq",
TokenURL: "http://localhost:3000/api/oauth/token",
Scopes: []string{"api:read", "api:write"},
EndpointParams: map[string][]string{
"audience": {"api://new-api"},
},
}
// 创建HTTP客户端
httpClient := cfg.Client(context.Background())
// 调用受保护的API
resp, err := httpClient.Get("http://localhost:3000/api/status")
if err != nil {
log.Printf("Request failed: %v", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read response: %v", err)
return
}
fmt.Printf("Status: %s\n", resp.Status)
fmt.Printf("Response: %s\n", string(body))
}
// testAuthorizationCode 测试授权码流程
func testAuthorizationCode() {
fmt.Println("=== Testing Authorization Code + PKCE Flow ===")
conf := oauth2.Config{
ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
ClientSecret: "JHiugKf89OMmTLuZMZyA2sgZnO0Ioae3",
RedirectURL: "http://localhost:9999/callback",
// 包含 openid/profile/email 以便调用 UserInfo
Scopes: []string{"openid", "profile", "email", "api:read"},
Endpoint: oauth2.Endpoint{
AuthURL: "http://localhost:3000/api/oauth/authorize",
TokenURL: "http://localhost:3000/api/oauth/token",
},
}
// 生成PKCE参数
codeVerifier := oauth2.GenerateVerifier()
state := fmt.Sprintf("state-%d", time.Now().Unix())
// 构建授权URL
url := conf.AuthCodeURL(
state,
oauth2.S256ChallengeOption(codeVerifier),
//oauth2.SetAuthURLParam("audience", "api://new-api"),
)
fmt.Printf("Visit this URL to authorize:\n%s\n\n", url)
fmt.Printf("A local server will listen on http://localhost:9999/callback to receive the code...\n")
// 启动回调本地服务器,自动接收授权码
codeCh := make(chan string, 1)
srv := &http.Server{Addr: ":9999"}
http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
if errParam := q.Get("error"); errParam != "" {
fmt.Fprintf(w, "Authorization failed: %s", errParam)
return
}
gotState := q.Get("state")
if gotState != state {
http.Error(w, "state mismatch", http.StatusBadRequest)
return
}
code := q.Get("code")
if code == "" {
http.Error(w, "missing code", http.StatusBadRequest)
return
}
fmt.Fprintln(w, "Authorization received. You may close this window.")
select {
case codeCh <- code:
default:
}
go func() {
// 稍后关闭服务
_ = srv.Shutdown(context.Background())
}()
})
go func() {
_ = srv.ListenAndServe()
}()
// 等待授权码
var code string
select {
case code = <-codeCh:
case <-time.After(5 * time.Minute):
log.Println("Timeout waiting for authorization code")
_ = srv.Shutdown(context.Background())
return
}
// 交换令牌
token, err := conf.Exchange(
context.Background(),
code,
oauth2.VerifierOption(codeVerifier),
)
if err != nil {
log.Printf("Token exchange failed: %v", err)
return
}
fmt.Printf("Access Token: %s\n", token.AccessToken)
fmt.Printf("Token Type: %s\n", token.TokenType)
fmt.Printf("Expires In: %v\n", token.Expiry)
// 使用令牌调用 UserInfo
client := conf.Client(context.Background(), token)
userInfoURL := buildUserInfoFromAuth(conf.Endpoint.AuthURL)
resp, err := client.Get(userInfoURL)
if err != nil {
log.Printf("UserInfo request failed: %v", err)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read UserInfo response: %v", err)
return
}
fmt.Printf("UserInfo: %s\n", string(body))
}
// buildUserInfoFromAuth 将授权端点URL转换为UserInfo端点URL
func buildUserInfoFromAuth(auth string) string {
u, err := url.Parse(auth)
if err != nil {
return ""
}
// 将最后一个路径段 authorize 替换为 userinfo
dir := path.Dir(u.Path)
if strings.HasSuffix(u.Path, "/authorize") {
u.Path = path.Join(dir, "userinfo")
} else {
// 回退:追加默认 /oauth/userinfo
u.Path = path.Join(dir, "userinfo")
}
return u.String()
}

23
go.mod
View File

@@ -11,20 +11,24 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
github.com/aws/smithy-go v1.22.5
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.9.0
github.com/go-oauth2/gin-server v1.1.0
github.com/go-oauth2/oauth2/v4 v4.5.4
github.com/go-playground/validator/v10 v10.20.0
github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/lestrrat-go/jwx/v2 v2.1.6
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.5.0
github.com/samber/lo v1.39.0
@@ -38,6 +42,7 @@ require (
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.35.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.11.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
@@ -55,6 +60,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
@@ -65,7 +71,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
@@ -79,14 +85,25 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.6 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 // indirect
github.com/tidwall/buntdb v1.1.2 // indirect
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e // indirect
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -94,7 +111,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

94
go.sum
View File

@@ -1,5 +1,7 @@
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
@@ -23,8 +25,8 @@ github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -39,16 +41,22 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8=
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
@@ -67,6 +75,10 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
github.com/go-oauth2/gin-server v1.1.0 h1:+7AyIfrcKaThZxxABRYECysxAfTccgpFdAqY1enuzBk=
github.com/go-oauth2/gin-server v1.1.0/go.mod h1:f08F3l5/Pbayb4pjnv5PpUdQLFejgGfHrTjA6IZb0eM=
github.com/go-oauth2/oauth2/v4 v4.5.4 h1:YjI0tmGW8oxVhn9QSBIxlr641QugWrJY5UWa6XmLcW0=
github.com/go-oauth2/oauth2/v4 v4.5.4/go.mod h1:BXiOY+QZtZy2ewbsGk2B5P8TWmtz/Rf7ES5ZttQFxfQ=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -90,20 +102,26 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@@ -112,6 +130,8 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -132,6 +152,10 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U=
github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
@@ -148,6 +172,18 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA=
github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -160,6 +196,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
@@ -184,10 +222,18 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -200,21 +246,35 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 h1:G6Z6HvJuPjG6XfNGi/feOATzeJrfgTNJY+rGrHbA04E=
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
github.com/tidwall/buntdb v1.1.2 h1:noCrqQXL9EKMtcdwJcmuVKSEjqu1ua99RHHgbLTEHRo=
github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI=
github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb h1:5NSYaAdrnblKByzd7XByQEJVT8+9v0W/tIY0Oo4OwrE=
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M=
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e h1:+NL1GDIUOKxVfbp2KoJQD9cTQ6dyP2co9q4yzmT9FZo=
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ3KKDyCkxapfufrqDqwmGjcHfAyXRrE=
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ=
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
@@ -229,8 +289,24 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
@@ -247,6 +323,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
@@ -257,12 +335,12 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

35
main.go
View File

@@ -1,7 +1,6 @@
package main
import (
"bytes"
"embed"
"fmt"
"log"
@@ -15,10 +14,9 @@ import (
"one-api/router"
"one-api/service"
"one-api/setting/ratio_setting"
"one-api/src/oauth"
"os"
"strconv"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
@@ -36,7 +34,6 @@ var buildFS embed.FS
var indexPage []byte
func main() {
startTime := time.Now()
err := InitResources()
if err != nil {
@@ -149,31 +146,11 @@ func main() {
})
server.Use(sessions.Sessions("session", store))
analyticsInjectBuilder := &strings.Builder{}
if os.Getenv("UMAMI_WEBSITE_ID") != "" {
umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
umamiScriptURL := os.Getenv("UMAMI_SCRIPT_URL")
if umamiScriptURL == "" {
umamiScriptURL = "https://analytics.umami.is/script.js"
}
analyticsInjectBuilder.WriteString("<script defer src=\"")
analyticsInjectBuilder.WriteString(umamiScriptURL)
analyticsInjectBuilder.WriteString("\" data-website-id=\"")
analyticsInjectBuilder.WriteString(umamiSiteID)
analyticsInjectBuilder.WriteString("\"></script>")
}
analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<analytics></analytics>\n"), []byte(analyticsInject))
router.SetRouter(server, buildFS, indexPage)
var port = os.Getenv("PORT")
if port == "" {
port = strconv.Itoa(*common.Port)
}
// Log startup success message
common.LogStartupSuccess(startTime, port)
err = server.Run(":" + port)
if err != nil {
common.FatalLog("failed to start HTTP server: " + err.Error())
@@ -227,5 +204,13 @@ func InitResources() error {
if err != nil {
return err
}
// Initialize OAuth2 server
err = oauth.InitOAuthServer()
if err != nil {
common.SysLog("Warning: Failed to initialize OAuth2 server: " + err.Error())
// OAuth2 失败不应该阻止系统启动
}
return nil
}
}

View File

@@ -8,11 +8,14 @@ import (
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"one-api/src/oauth"
"strconv"
"strings"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
jwt "github.com/golang-jwt/jwt/v5"
)
func validUserInfo(username string, role int) bool {
@@ -177,6 +180,7 @@ func WssAuth(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
rawAuth := c.Request.Header.Get("Authorization")
// 先检测是否为ws
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
@@ -235,6 +239,11 @@ func TokenAuth() func(c *gin.Context) {
}
}
if err != nil {
// OAuth Bearer fallback
if tryOAuthBearer(c, rawAuth) {
c.Next()
return
}
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
@@ -288,6 +297,74 @@ func TokenAuth() func(c *gin.Context) {
}
}
// tryOAuthBearer validates an OAuth JWT access token and sets minimal context for relay
func tryOAuthBearer(c *gin.Context, rawAuth string) bool {
if rawAuth == "" || !strings.HasPrefix(rawAuth, "Bearer ") {
return false
}
tokenString := strings.TrimSpace(strings.TrimPrefix(rawAuth, "Bearer "))
if tokenString == "" {
return false
}
settings := system_setting.GetOAuth2Settings()
// Parse & verify
parsed, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
return nil, jwt.ErrTokenSignatureInvalid
}
if kid, ok := t.Header["kid"].(string); ok {
if settings.JWTKeyID != "" && kid != settings.JWTKeyID {
return nil, jwt.ErrTokenSignatureInvalid
}
}
pub := oauth.GetRSAPublicKey()
if pub == nil {
return nil, jwt.ErrTokenUnverifiable
}
return pub, nil
})
if err != nil || parsed == nil || !parsed.Valid {
return false
}
claims, ok := parsed.Claims.(jwt.MapClaims)
if !ok {
return false
}
// issuer check when configured
if iss, ok2 := claims["iss"].(string); !ok2 || (settings.Issuer != "" && iss != settings.Issuer) {
return false
}
// revoke check
if jti, ok2 := claims["jti"].(string); ok2 && jti != "" {
if revoked, _ := model.IsTokenRevoked(jti); revoked {
return false
}
}
// scope check: must contain api:read or api:write or admin
scope, _ := claims["scope"].(string)
scopePadded := " " + scope + " "
if !(strings.Contains(scopePadded, " api:read ") || strings.Contains(scopePadded, " api:write ") || strings.Contains(scopePadded, " admin ")) {
return false
}
// subject must be user id to support quota logic
sub, _ := claims["sub"].(string)
uid, err := strconv.Atoi(sub)
if err != nil || uid <= 0 {
return false
}
// load user cache & set context
userCache, err := model.GetUserCache(uid)
if err != nil || userCache == nil || userCache.Status != common.UserStatusEnabled {
return false
}
c.Set("id", uid)
c.Set("group", userCache.Group)
c.Set("user_group", userCache.Group)
// set UsingGroup
common.SetContextKey(c, constant.ContextKeyUsingGroup, userCache.Group)
return true
}
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
if token == nil {
return fmt.Errorf("token is nil")

291
middleware/oauth_jwt.go Normal file
View File

@@ -0,0 +1,291 @@
package middleware
import (
"crypto/rsa"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting/system_setting"
"one-api/src/oauth"
"strings"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
// OAuthJWTAuth OAuth2 JWT认证中间件
func OAuthJWTAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 检查OAuth2是否启用
settings := system_setting.GetOAuth2Settings()
if !settings.Enabled {
c.Next()
return
}
// 获取Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.Next() // 没有Authorization header继续到下一个中间件
return
}
// 检查是否为Bearer token
if !strings.HasPrefix(authHeader, "Bearer ") {
c.Next() // 不是Bearer token继续到下一个中间件
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == "" {
abortWithOAuthError(c, "invalid_token", "Missing token")
return
}
// 验证JWT token
claims, err := validateOAuthJWT(tokenString)
if err != nil {
abortWithOAuthError(c, "invalid_token", err.Error())
return
}
// 验证token的有效性
if err := validateOAuthClaims(claims); err != nil {
abortWithOAuthError(c, "invalid_token", err.Error())
return
}
// 设置上下文信息
setOAuthContext(c, claims)
c.Next()
}
}
// validateOAuthJWT 验证OAuth2 JWT令牌
func validateOAuthJWT(tokenString string) (jwt.MapClaims, error) {
// 解析JWT而不验证签名先获取header中的kid
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 检查签名方法
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
// 获取kid
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, fmt.Errorf("missing kid in token header")
}
// 根据kid获取公钥
publicKey, err := getPublicKeyByKid(kid)
if err != nil {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
return publicKey, nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
if !token.Valid {
return nil, fmt.Errorf("invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid token claims")
}
return claims, nil
}
// getPublicKeyByKid 根据kid获取公钥
func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) {
// 这里需要从JWKS获取公钥
// 在实际实现中你可能需要从OAuth server获取JWKS
// 这里先实现一个简单版本
// TODO: 实现JWKS缓存和刷新机制
pub := oauth.GetPublicKeyByKid(kid)
if pub == nil {
return nil, fmt.Errorf("unknown kid: %s", kid)
}
return pub, nil
}
// validateOAuthClaims 验证OAuth2 claims
func validateOAuthClaims(claims jwt.MapClaims) error {
settings := system_setting.GetOAuth2Settings()
// 验证issuer若配置了 Issuer 则强校验,否则仅要求存在)
if iss, ok := claims["iss"].(string); ok {
if settings.Issuer != "" && iss != settings.Issuer {
return fmt.Errorf("invalid issuer")
}
} else {
return fmt.Errorf("missing issuer claim")
}
// 验证audience
// if aud, ok := claims["aud"].(string); ok {
// // TODO: 验证audience
// }
// 验证客户端ID
if clientID, ok := claims["client_id"].(string); ok {
// 验证客户端是否存在且有效
client, err := model.GetOAuthClientByID(clientID)
if err != nil {
return fmt.Errorf("invalid client")
}
if client.Status != common.UserStatusEnabled {
return fmt.Errorf("client disabled")
}
// 检查是否被撤销
if jti, ok := claims["jti"].(string); ok && jti != "" {
revoked, _ := model.IsTokenRevoked(jti)
if revoked {
return fmt.Errorf("token revoked")
}
}
} else {
return fmt.Errorf("missing client_id claim")
}
return nil
}
// setOAuthContext 设置OAuth上下文信息
func setOAuthContext(c *gin.Context, claims jwt.MapClaims) {
c.Set("oauth_claims", claims)
c.Set("oauth_authenticated", true)
// 提取基本信息
if clientID, ok := claims["client_id"].(string); ok {
c.Set("oauth_client_id", clientID)
}
if scope, ok := claims["scope"].(string); ok {
c.Set("oauth_scope", scope)
}
if sub, ok := claims["sub"].(string); ok {
c.Set("oauth_subject", sub)
}
// 对于client_credentials流程subject就是client_id
// 对于authorization_code流程subject是用户ID
if grantType, ok := claims["grant_type"].(string); ok {
c.Set("oauth_grant_type", grantType)
}
}
// abortWithOAuthError 返回OAuth错误响应
func abortWithOAuthError(c *gin.Context, errorCode, description string) {
c.Header("WWW-Authenticate", fmt.Sprintf(`Bearer error="%s", error_description="%s"`, errorCode, description))
c.JSON(http.StatusUnauthorized, gin.H{
"error": errorCode,
"error_description": description,
})
c.Abort()
}
// RequireOAuthScope OAuth2 scope验证中间件
func RequireOAuthScope(requiredScope string) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查是否通过OAuth认证
if !c.GetBool("oauth_authenticated") {
abortWithOAuthError(c, "insufficient_scope", "OAuth2 authentication required")
return
}
// 获取token的scope
scope, exists := c.Get("oauth_scope")
if !exists {
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
return
}
scopeStr, ok := scope.(string)
if !ok {
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
return
}
// 检查是否包含所需的scope
scopes := strings.Split(scopeStr, " ")
for _, s := range scopes {
if strings.TrimSpace(s) == requiredScope {
c.Next()
return
}
}
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
}
}
// OptionalOAuthAuth 可选的OAuth认证中间件不会阻止请求
func OptionalOAuthAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试OAuth认证但不会阻止请求
authHeader := c.GetHeader("Authorization")
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if claims, err := validateOAuthJWT(tokenString); err == nil {
if validateOAuthClaims(claims) == nil {
setOAuthContext(c, claims)
}
}
}
c.Next()
}
}
// RequireOAuthScopeIfPresent enforces scope only when OAuth is present; otherwise no-op
func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc {
return func(c *gin.Context) {
if !c.GetBool("oauth_authenticated") {
c.Next()
return
}
scope, exists := c.Get("oauth_scope")
if !exists {
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
return
}
scopeStr, ok := scope.(string)
if !ok {
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
return
}
scopes := strings.Split(scopeStr, " ")
for _, s := range scopes {
if strings.TrimSpace(s) == requiredScope {
c.Next()
return
}
}
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
}
}
// GetOAuthClaims 获取OAuth claims
func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) {
claims, exists := c.Get("oauth_claims")
if !exists {
return nil, false
}
mapClaims, ok := claims.(jwt.MapClaims)
return mapClaims, ok
}
// IsOAuthAuthenticated 检查是否通过OAuth认证
func IsOAuthAuthenticated(c *gin.Context) bool {
return c.GetBool("oauth_authenticated")
}

View File

@@ -265,6 +265,7 @@ func migrateDB() error {
&Setup{},
&TwoFA{},
&TwoFABackupCode{},
&OAuthClient{},
)
if err != nil {
return err

183
model/oauth_client.go Normal file
View File

@@ -0,0 +1,183 @@
package model
import (
"encoding/json"
"one-api/common"
"strings"
"time"
"gorm.io/gorm"
)
// OAuthClient OAuth2 客户端模型
type OAuthClient struct {
ID string `json:"id" gorm:"type:varchar(64);primaryKey"`
Secret string `json:"secret" gorm:"type:varchar(128);not null"`
Name string `json:"name" gorm:"type:varchar(255);not null"`
Domain string `json:"domain" gorm:"type:varchar(255)"` // 允许的重定向域名
RedirectURIs string `json:"redirect_uris" gorm:"type:text"` // JSON array of redirect URIs
GrantTypes string `json:"grant_types" gorm:"type:varchar(255);default:'client_credentials'"`
Scopes string `json:"scopes" gorm:"type:varchar(255);default:'api:read'"`
RequirePKCE bool `json:"require_pkce" gorm:"default:true"`
Status int `json:"status" gorm:"type:int;default:1"` // 1: enabled, 2: disabled
CreatedBy int `json:"created_by" gorm:"type:int;not null"` // 创建者用户ID
CreatedTime int64 `json:"created_time" gorm:"bigint"`
LastUsedTime int64 `json:"last_used_time" gorm:"bigint;default:0"`
TokenCount int `json:"token_count" gorm:"type:int;default:0"` // 已签发的token数量
Description string `json:"description" gorm:"type:text"`
ClientType string `json:"client_type" gorm:"type:varchar(32);default:'confidential'"` // confidential, public
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// GetRedirectURIs 获取重定向URI列表
func (c *OAuthClient) GetRedirectURIs() []string {
if c.RedirectURIs == "" {
return []string{}
}
var uris []string
err := json.Unmarshal([]byte(c.RedirectURIs), &uris)
if err != nil {
common.SysLog("failed to unmarshal redirect URIs: " + err.Error())
return []string{}
}
return uris
}
// SetRedirectURIs 设置重定向URI列表
func (c *OAuthClient) SetRedirectURIs(uris []string) {
data, err := json.Marshal(uris)
if err != nil {
common.SysLog("failed to marshal redirect URIs: " + err.Error())
return
}
c.RedirectURIs = string(data)
}
// GetGrantTypes 获取允许的授权类型列表
func (c *OAuthClient) GetGrantTypes() []string {
if c.GrantTypes == "" {
return []string{"client_credentials"}
}
return strings.Split(c.GrantTypes, ",")
}
// SetGrantTypes 设置允许的授权类型列表
func (c *OAuthClient) SetGrantTypes(types []string) {
c.GrantTypes = strings.Join(types, ",")
}
// GetScopes 获取允许的scope列表
func (c *OAuthClient) GetScopes() []string {
if c.Scopes == "" {
return []string{"api:read"}
}
return strings.Split(c.Scopes, ",")
}
// SetScopes 设置允许的scope列表
func (c *OAuthClient) SetScopes(scopes []string) {
c.Scopes = strings.Join(scopes, ",")
}
// ValidateRedirectURI 验证重定向URI是否有效
func (c *OAuthClient) ValidateRedirectURI(uri string) bool {
allowedURIs := c.GetRedirectURIs()
for _, allowedURI := range allowedURIs {
if allowedURI == uri {
return true
}
}
return false
}
// ValidateGrantType 验证授权类型是否被允许
func (c *OAuthClient) ValidateGrantType(grantType string) bool {
allowedTypes := c.GetGrantTypes()
for _, allowedType := range allowedTypes {
if allowedType == grantType {
return true
}
}
return false
}
// ValidateScope 验证scope是否被允许
func (c *OAuthClient) ValidateScope(scope string) bool {
allowedScopes := c.GetScopes()
requestedScopes := strings.Split(scope, " ")
for _, requestedScope := range requestedScopes {
requestedScope = strings.TrimSpace(requestedScope)
if requestedScope == "" {
continue
}
found := false
for _, allowedScope := range allowedScopes {
if allowedScope == requestedScope {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// BeforeCreate GORM hook - 在创建前设置时间
func (c *OAuthClient) BeforeCreate(tx *gorm.DB) (err error) {
c.CreatedTime = time.Now().Unix()
return
}
// UpdateLastUsedTime 更新最后使用时间
func (c *OAuthClient) UpdateLastUsedTime() error {
c.LastUsedTime = time.Now().Unix()
c.TokenCount++
return DB.Model(c).Select("last_used_time", "token_count").Updates(c).Error
}
// GetOAuthClientByID 根据ID获取OAuth客户端
func GetOAuthClientByID(id string) (*OAuthClient, error) {
var client OAuthClient
err := DB.Where("id = ? AND status = ?", id, common.UserStatusEnabled).First(&client).Error
return &client, err
}
// GetAllOAuthClients 获取所有OAuth客户端
func GetAllOAuthClients(startIdx int, num int) ([]*OAuthClient, error) {
var clients []*OAuthClient
err := DB.Order("created_time desc").Limit(num).Offset(startIdx).Find(&clients).Error
return clients, err
}
// SearchOAuthClients 搜索OAuth客户端
func SearchOAuthClients(keyword string) ([]*OAuthClient, error) {
var clients []*OAuthClient
err := DB.Where("name LIKE ? OR id LIKE ? OR description LIKE ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%").Find(&clients).Error
return clients, err
}
// CreateOAuthClient 创建OAuth客户端
func CreateOAuthClient(client *OAuthClient) error {
return DB.Create(client).Error
}
// UpdateOAuthClient 更新OAuth客户端
func UpdateOAuthClient(client *OAuthClient) error {
return DB.Save(client).Error
}
// DeleteOAuthClient 删除OAuth客户端
func DeleteOAuthClient(id string) error {
return DB.Where("id = ?", id).Delete(&OAuthClient{}).Error
}
// CountOAuthClients 统计OAuth客户端数量
func CountOAuthClients() (int64, error) {
var count int64
err := DB.Model(&OAuthClient{}).Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,57 @@
package model
import (
"fmt"
"one-api/common"
"sync"
"time"
)
var revokedMem sync.Map // jti -> exp(unix)
func RevokeToken(jti string, exp int64) error {
if jti == "" {
return nil
}
// Prefer Redis, else in-memory
if common.RedisEnabled {
ttl := time.Duration(0)
if exp > 0 {
ttl = time.Until(time.Unix(exp, 0))
}
if ttl <= 0 {
ttl = time.Minute
}
key := fmt.Sprintf("oauth:revoked:%s", jti)
return common.RedisSet(key, "1", ttl)
}
if exp <= 0 {
exp = time.Now().Add(time.Minute).Unix()
}
revokedMem.Store(jti, exp)
return nil
}
func IsTokenRevoked(jti string) (bool, error) {
if jti == "" {
return false, nil
}
if common.RedisEnabled {
key := fmt.Sprintf("oauth:revoked:%s", jti)
if _, err := common.RedisGet(key); err == nil {
return true, nil
} else {
// Not found or error; treat as not revoked on error to avoid hard failures
return false, nil
}
}
// In-memory check
if v, ok := revokedMem.Load(jti); ok {
exp, _ := v.(int64)
if exp == 0 || time.Now().Unix() <= exp {
return true, nil
}
revokedMem.Delete(jti)
}
return false, nil
}

View File

@@ -82,7 +82,6 @@ func InitOptionMap() {
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
common.OptionMap["StripePriceId"] = setting.StripePriceId
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
@@ -113,9 +112,6 @@ func InitOptionMap() {
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -331,8 +327,6 @@ func updateOptionMap(key string, value string) (err error) {
setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
case "StripeMinTopUp":
setting.StripeMinTopUp, _ = strconv.Atoi(value)
case "StripePromotionCodesEnabled":
setting.StripePromotionCodesEnabled = value == "true"
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
@@ -403,12 +397,6 @@ func updateOptionMap(key string, value string) (err error) {
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "ImageRatio":
err = ratio_setting.UpdateImageRatioByJSONString(value)
case "AudioRatio":
err = ratio_setting.UpdateAudioRatioByJSONString(value)
case "AudioCompletionRatio":
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":

View File

@@ -265,7 +265,6 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
resp, err := client.Do(req)
if err != nil {
logger.LogError(c, "do request failed: "+err.Error())
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
}
if resp == nil {

View File

@@ -21,10 +21,6 @@ var awsModelIDMap = map[string]string{
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
"nova-canvas-v1:0": "amazon.nova-canvas-v1:0",
"nova-reel-v1:0": "amazon.nova-reel-v1:0",
"nova-reel-v1:1": "amazon.nova-reel-v1:1",
"nova-sonic-v1:0": "amazon.nova-sonic-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -86,27 +82,10 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"apac": true,
},
"amazon.nova-premier-v1:0": {
"us": true,
},
"amazon.nova-canvas-v1:0": {
"us": true,
"eu": true,
"apac": true,
},
"amazon.nova-reel-v1:0": {
"us": true,
"eu": true,
"apac": true,
},
"amazon.nova-reel-v1:1": {
"us": true,
},
"amazon.nova-sonic-v1:0": {
"us": true,
"eu": true,
"apac": true,
},
}
}}
var awsRegionCrossModelPrefixMap = map[string]string{
"us": "us",

View File

@@ -3,17 +3,17 @@ package deepseek
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := claude.Adaptor{}
adaptor := openai.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
@@ -44,19 +44,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.ChannelBaseUrl
switch info.RelayFormat {
case types.RelayFormatClaude:
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default:
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
}
@@ -92,17 +87,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {

View File

@@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if strings.Contains(info.RequestURLPath, ":embedContent") ||
strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
return NativeGeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream {

View File

@@ -23,7 +23,6 @@ import (
"github.com/gin-gonic/gin"
)
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
var geminiSupportedMimeTypes = map[string]bool{
"application/pdf": true,
"audio/mpeg": true,
@@ -31,7 +30,6 @@ var geminiSupportedMimeTypes = map[string]bool{
"audio/wav": true,
"image/png": true,
"image/jpeg": true,
"image/webp": true,
"text/plain": true,
"video/mov": true,
"video/mpeg": true,
@@ -245,7 +243,6 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
urlContext := false
for _, tool := range textRequest.Tools {
if tool.Function.Name == "googleSearch" {
googleSearch = true
@@ -255,10 +252,6 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
codeExecution = true
continue
}
if tool.Function.Name == "urlContext" {
urlContext = true
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
@@ -286,11 +279,6 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
GoogleSearch: make(map[string]string),
})
}
if urlContext {
geminiTools = append(geminiTools, dto.GeminiChatTool{
URLContext: make(map[string]string),
})
}
if len(functions) > 0 {
geminiTools = append(geminiTools, dto.GeminiChatTool{
FunctionDeclarations: functions,

View File

@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := claude.Adaptor{}
adaptor := openai.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}

View File

@@ -10,7 +10,6 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
@@ -18,7 +17,10 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
openaiAdaptor := openai.Adaptor{}
@@ -29,21 +31,32 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
return info.ChannelBaseUrl + "/api/chat", nil
if info.RelayFormat == types.RelayFormatClaude {
return info.ChannelBaseUrl + "/v1/chat/completions", nil
}
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.ChannelBaseUrl + "/api/embed", nil
default:
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -53,12 +66,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil { return nil, errors.New("request is nil") }
// decide generate or chat
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
return openAIToGenerate(c, request)
if request == nil {
return nil, errors.New("request is nil")
}
return openAIChatToOllamaChat(c, request)
return requestOpenAI2Ollama(c, request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -69,7 +80,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
@@ -78,13 +92,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return ollamaEmbeddingHandler(c, info, resp)
usage, err = ollamaEmbeddingHandler(c, info, resp)
default:
if info.IsStream {
return ollamaStreamHandler(c, info, resp)
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return ollamaChatHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {

View File

@@ -2,69 +2,48 @@ package ollama
import (
"encoding/json"
"one-api/dto"
)
type OllamaChatMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
Images []string `json:"images,omitempty"`
ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
Thinking json.RawMessage `json:"thinking,omitempty"`
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
Think json.RawMessage `json:"think,omitempty"`
}
type OllamaToolFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Parameters interface{} `json:"parameters,omitempty"`
}
type OllamaTool struct {
Type string `json:"type"`
Function OllamaToolFunction `json:"function"`
}
type OllamaToolCall struct {
Function struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
} `json:"function"`
}
type OllamaChatRequest struct {
Model string `json:"model"`
Messages []OllamaChatMessage `json:"messages"`
Tools interface{} `json:"tools,omitempty"`
Format interface{} `json:"format,omitempty"`
Stream bool `json:"stream,omitempty"`
Options map[string]any `json:"options,omitempty"`
KeepAlive interface{} `json:"keep_alive,omitempty"`
Think json.RawMessage `json:"think,omitempty"`
}
type OllamaGenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
Suffix string `json:"suffix,omitempty"`
Images []string `json:"images,omitempty"`
Format interface{} `json:"format,omitempty"`
Stream bool `json:"stream,omitempty"`
Options map[string]any `json:"options,omitempty"`
KeepAlive interface{} `json:"keep_alive,omitempty"`
Think json.RawMessage `json:"think,omitempty"`
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type OllamaEmbeddingRequest struct {
Model string `json:"model"`
Input interface{} `json:"input"`
Options map[string]any `json:"options,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Model string `json:"model,omitempty"`
Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}

View File

@@ -1,7 +1,6 @@
package ollama
import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -15,176 +14,121 @@ import (
"github.com/gin-gonic/gin"
)
func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
chatReq := &OllamaChatRequest{
Model: r.Model,
Stream: r.Stream,
Options: map[string]any{},
Think: r.Think,
}
if r.ResponseFormat != nil {
if r.ResponseFormat.Type == "json" {
chatReq.Format = "json"
} else if r.ResponseFormat.Type == "json_schema" {
if len(r.ResponseFormat.JsonSchema) > 0 {
var schema any
_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
chatReq.Format = schema
}
}
}
// options mapping
if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
if r.Stop != nil {
switch v := r.Stop.(type) {
case string:
chatReq.Options["stop"] = []string{v}
case []string:
chatReq.Options["stop"] = v
case []any:
arr := make([]string,0,len(v))
for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
if len(arr)>0 { chatReq.Options["stop"] = arr }
}
}
if len(r.Tools) > 0 {
tools := make([]OllamaTool,0,len(r.Tools))
for _, t := range r.Tools {
tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
}
chatReq.Tools = tools
}
chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
for _, m := range r.Messages {
var textBuilder strings.Builder
var images []string
if m.IsStringContent() {
textBuilder.WriteString(m.StringContent())
} else {
parts := m.ParseContent()
for _, part := range parts {
if part.Type == dto.ContentTypeImageURL {
img := part.GetImageMedia()
if img != nil && img.Url != "" {
var base64Data string
if strings.HasPrefix(img.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
if err != nil { return nil, err }
base64Data = fileData.Base64Data
} else if strings.HasPrefix(img.Url, "data:") {
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
} else {
base64Data = img.Url
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.GetImageMedia()
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
if err != nil {
return nil, err
}
if base64Data != "" { images = append(images, base64Data) }
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
} else if part.Type == dto.ContentTypeText {
textBuilder.WriteString(part.Text)
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
if len(images)>0 { cm.Images = images }
if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
parsed := m.ParseToolCalls()
if len(parsed) > 0 {
calls := make([]OllamaToolCall,0,len(parsed))
for _, tc := range parsed {
var args interface{}
if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
if args==nil { args = map[string]any{} }
oc := OllamaToolCall{}
oc.Function.Name = tc.Function.Name
oc.Function.Arguments = args
calls = append(calls, oc)
}
cm.ToolCalls = calls
}
}
chatReq.Messages = append(chatReq.Messages, cm)
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
return chatReq, nil
str, ok := request.Stop.(string)
var Stop []string
if ok {
Stop = []string{str}
} else {
Stop, _ = request.Stop.([]string)
}
ollamaRequest := &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.GetMaxTokens(),
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}
ollamaRequest.Think = request.Think
return ollamaRequest, nil
}
// openAIToGenerate converts OpenAI completions request to Ollama generate
func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
gen := &OllamaGenerateRequest{
Model: r.Model,
Stream: r.Stream,
Options: map[string]any{},
Think: r.Think,
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
}
// Prompt may be in r.Prompt (string or []any)
if r.Prompt != nil {
switch v := r.Prompt.(type) {
case string:
gen.Prompt = v
case []any:
var sb strings.Builder
for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
gen.Prompt = sb.String()
default:
gen.Prompt = fmt.Sprintf("%v", r.Prompt)
}
}
if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
if r.ResponseFormat != nil {
if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
}
if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
if r.Stop != nil {
switch v := r.Stop.(type) {
case string: gen.Options["stop"] = []string{v}
case []string: gen.Options["stop"] = v
case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
}
}
return gen, nil
}
func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
opts := map[string]any{}
if r.Temperature != nil { opts["temperature"] = r.Temperature }
if r.TopP != 0 { opts["top_p"] = r.TopP }
if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
if r.Seed != 0 { opts["seed"] = int(r.Seed) }
if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
input := r.ParseInput()
if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
}
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var oResp OllamaEmbeddingResponse
body, err := io.ReadAll(resp.Body)
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
service.CloseResponseBodyGracefully(resp)
if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
out, _ := common.Marshal(embResp)
service.IOCopyBytesGracefully(c, resp, out)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if ollamaEmbeddingResponse.Error != "" {
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: flattenedEmbeddings,
Object: "embedding",
})
usage := &dto.Usage{
TotalTokens: info.PromptTokens,
CompletionTokens: 0,
PromptTokens: info.PromptTokens,
}
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: "list",
Data: data,
Model: info.UpstreamModelName,
Usage: *usage,
}
doResponseBody, err := common.Marshal(embeddingResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
service.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil
}
func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
}
return flattened
}

View File

@@ -1,210 +0,0 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type ollamaChatStreamChunk struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
// chat
Message *struct {
Role string `json:"role"`
Content string `json:"content"`
Thinking json.RawMessage `json:"thinking"`
ToolCalls []struct {
Function struct {
Name string `json:"name"`
Arguments interface{} `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
} `json:"message"`
// generate
Response string `json:"response"`
Done bool `json:"done"`
DoneReason string `json:"done_reason"`
TotalDuration int64 `json:"total_duration"`
LoadDuration int64 `json:"load_duration"`
PromptEvalCount int `json:"prompt_eval_count"`
EvalCount int `json:"eval_count"`
PromptEvalDuration int64 `json:"prompt_eval_duration"`
EvalDuration int64 `json:"eval_duration"`
}
func toUnix(ts string) int64 {
if ts == "" { return time.Now().Unix() }
// try time.RFC3339 or with nanoseconds
t, err := time.Parse(time.RFC3339Nano, ts)
if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
return t.Unix()
}
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
defer service.CloseResponseBodyGracefully(resp)
helper.SetEventStreamHeaders(c)
scanner := bufio.NewScanner(resp.Body)
usage := &dto.Usage{}
var model = info.UpstreamModelName
var responseId = common.GetUUID()
var created = time.Now().Unix()
var toolCallIndex int
start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
if line == "" { continue }
var chunk ollamaChatStreamChunk
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if chunk.Model != "" { model = chunk.Model }
created = toUnix(chunk.CreatedAt)
if !chunk.Done {
// delta content
var content string
if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
delta := dto.ChatCompletionsStreamResponse{
Id: responseId,
Object: "chat.completion.chunk",
Created: created,
Model: model,
Choices: []dto.ChatCompletionsStreamResponseChoice{ {
Index: 0,
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
} },
}
if content != "" { delta.Choices[0].Delta.SetContentString(content) }
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(chunk.Message.Thinking))
if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
}
// tool calls
if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
for _, tc := range chunk.Message.ToolCalls {
// arguments -> string
argBytes, _ := json.Marshal(tc.Function.Arguments)
toolId := fmt.Sprintf("call_%d", toolCallIndex)
tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
tr.SetIndex(toolCallIndex)
toolCallIndex++
delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
}
}
if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
continue
}
// done frame
// finalize once and break loop
usage.PromptTokens = chunk.PromptEvalCount
usage.CompletionTokens = chunk.EvalCount
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
finishReason := chunk.DoneReason
if finishReason == "" { finishReason = "stop" }
// emit stop delta
if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
}
// emit usage frame
if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
}
// send [DONE]
helper.Done(c)
break
}
if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
return usage, nil
}
// non-stream handler for chat/generate
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
body, err := io.ReadAll(resp.Body)
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
service.CloseResponseBodyGracefully(resp)
raw := string(body)
if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
lines := strings.Split(raw, "\n")
var (
aggContent strings.Builder
reasoningBuilder strings.Builder
lastChunk ollamaChatStreamChunk
parsedAny bool
)
for _, ln := range lines {
ln = strings.TrimSpace(ln)
if ln == "" { continue }
var ck ollamaChatStreamChunk
if err := json.Unmarshal([]byte(ln), &ck); err != nil {
if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
continue
}
parsedAny = true
lastChunk = ck
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
raw := strings.TrimSpace(string(ck.Message.Thinking))
if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
}
if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
}
if !parsedAny {
var single ollamaChatStreamChunk
if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
lastChunk = single
if single.Message != nil {
if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
aggContent.WriteString(single.Message.Content)
} else { aggContent.WriteString(single.Response) }
}
model := lastChunk.Model
if model == "" { model = info.UpstreamModelName }
created := toUnix(lastChunk.CreatedAt)
usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
content := aggContent.String()
finishReason := lastChunk.DoneReason
if finishReason == "" { finishReason = "stop" }
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
full := dto.OpenAITextResponse{
Id: common.GetUUID(),
Model: model,
Object: "chat.completion",
Created: created,
Choices: []dto.OpenAITextResponseChoice{ {
Index: 0,
Message: msg,
FinishReason: finishReason,
} },
Usage: *usage,
}
out, _ := common.Marshal(full)
service.IOCopyBytesGracefully(c, resp, out)
return usage, nil
}
func contentPtr(s string) *string { if s=="" { return nil }; return &s }

View File

@@ -12,7 +12,6 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -186,27 +185,10 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if common.DebugEnabled {
println("upstream response body:", string(responseBody))
}
// Unmarshal to simpleResponse
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
// 尝试解析为 openrouter enterprise
var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
err = common.Unmarshal(responseBody, &enterpriseResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if enterpriseResponse.Success {
responseBody = enterpriseResponse.Data
} else {
logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
}
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}

View File

@@ -33,12 +33,6 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
if responsesResponse.HasImageGenerationCall() {
c.Set("image_generation_call", true)
c.Set("image_generation_call_quality", responsesResponse.GetQuality())
c.Set("image_generation_call_size", responsesResponse.GetSize())
}
// 写入新的 response body
service.IOCopyBytesGracefully(c, resp, responseBody)
@@ -86,25 +80,18 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
if streamResponse.Response != nil {
if streamResponse.Response.Usage != nil {
if streamResponse.Response.Usage.InputTokens != 0 {
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
}
if streamResponse.Response.Usage.OutputTokens != 0 {
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
}
if streamResponse.Response.Usage.TotalTokens != 0 {
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
}
if streamResponse.Response.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
}
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
if streamResponse.Response.Usage.InputTokens != 0 {
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
}
if streamResponse.Response.HasImageGenerationCall() {
c.Set("image_generation_call", true)
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
if streamResponse.Response.Usage.OutputTokens != 0 {
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
}
if streamResponse.Response.Usage.TotalTokens != 0 {
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
}
if streamResponse.Response.Usage.InputTokensDetails != nil {
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
}
}
case "response.output_text.delta":

View File

@@ -1,7 +1,5 @@
package openrouter
import "encoding/json"
type RequestReasoning struct {
// One of the following (not both):
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
@@ -9,8 +7,3 @@ type RequestReasoning struct {
// Optional: Default is false. All models support this.
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
}
type OpenRouterEnterpriseResponse struct {
Data json.RawMessage `json:"data"`
Success bool `json:"success"`
}

View File

@@ -94,9 +94,6 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
if isNewAPIRelay(info.ApiKey) {
return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
}
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
}
@@ -104,12 +101,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
if isNewAPIRelay(info.ApiKey) {
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
} else {
return a.signRequest(req, a.accessKey, a.secretKey)
}
return nil
return a.signRequest(req, a.accessKey, a.secretKey)
}
// BuildRequestBody converts request into Jimeng specific format.
@@ -169,9 +161,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
}
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
if isNewAPIRelay(key) {
uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
}
payload := map[string]string{
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
@@ -189,20 +178,17 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
if isNewAPIRelay(key) {
req.Header.Set("Authorization", "Bearer "+key)
} else {
keyParts := strings.Split(key, "|")
if len(keyParts) != 2 {
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
}
accessKey := strings.TrimSpace(keyParts[0])
secretKey := strings.TrimSpace(keyParts[1])
if err := a.signRequest(req, accessKey, secretKey); err != nil {
return nil, errors.Wrap(err, "sign request failed")
}
keyParts := strings.Split(key, "|")
if len(keyParts) != 2 {
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
}
accessKey := strings.TrimSpace(keyParts[0])
secretKey := strings.TrimSpace(keyParts[1])
if err := a.signRequest(req, accessKey, secretKey); err != nil {
return nil, errors.Wrap(err, "sign request failed")
}
return service.GetHttpClient().Do(req)
}
@@ -398,7 +384,3 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
taskResult.Url = resTask.Data.VideoUrl
return &taskResult, nil
}
func isNewAPIRelay(apiKey string) bool {
return strings.HasPrefix(apiKey, "sk-")
}

View File

@@ -117,11 +117,6 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
if isNewAPIRelay(info.ApiKey) {
return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
}
return fmt.Sprintf("%s%s", a.baseURL, path), nil
}
@@ -204,9 +199,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
}
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
if isNewAPIRelay(key) {
url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@@ -312,13 +304,8 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
//}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
if isNewAPIRelay(apiKey) {
return apiKey, nil // new api relay
}
keyParts := strings.Split(apiKey, "|")
if len(keyParts) != 2 {
return "", errors.New("invalid api_key, required format is accessKey|secretKey")
}
accessKey := strings.TrimSpace(keyParts[0])
if len(keyParts) == 1 {
return accessKey, nil
@@ -365,7 +352,3 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
return taskInfo, nil
}
func isNewAPIRelay(apiKey string) bool {
return strings.HasPrefix(apiKey, "sk-")
}

View File

@@ -80,7 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
// Use the unified validation method for TaskSubmitReq with image-based action determination
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
@@ -111,10 +112,6 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
switch info.Action {
case constant.TaskActionGenerate:
path = "/img2video"
case constant.TaskActionFirstTailGenerate:
path = "/start-end2video"
case constant.TaskActionReferenceGenerate:
path = "/reference2video"
default:
path = "/text2video"
}
@@ -190,9 +187,14 @@ func (a *TaskAdaptor) GetChannelName() string {
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
var images []string
if req.Image != "" {
images = []string{req.Image}
}
r := requestPayload{
Model: defaultString(req.Model, "viduq1"),
Images: req.Images,
Images: images,
Prompt: req.Prompt,
Duration: defaultInt(req.Duration, 5),
Resolution: defaultString(req.Size, "1080p"),

View File

@@ -9,7 +9,6 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
channelconstant "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
@@ -42,8 +41,6 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
return request, nil
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
@@ -189,26 +186,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 支持自定义域名,如果未设置则使用默认域名
baseUrl := info.ChannelBaseUrl
if baseUrl == "" {
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
}
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

View File

@@ -8,12 +8,6 @@ var ModelList = []string{
"Doubao-lite-32k",
"Doubao-lite-4k",
"Doubao-embedding",
"doubao-seedream-4-0-250828",
"seedream-4-0-250828",
"doubao-seedance-1-0-pro-250528",
"seedance-1-0-pro-250528",
"doubao-seed-1-6-thinking-250715",
"seed-1-6-thinking-250715",
}
var ChannelName = "volcengine"

View File

@@ -207,6 +207,10 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
return nil, nil, err
}
defer func() {
conn.Close()
}()
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
@@ -216,9 +220,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool)
go func() {
defer func() {
conn.Close()
}()
for {
_, msg, err := conn.ReadMessage()
if err != nil {

View File

@@ -6,7 +6,6 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -70,31 +69,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
info.UpstreamModelName = request.Model
}
if info.ChannelSetting.SystemPrompt != "" {
if request.System == nil {
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
} else if info.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
if request.IsStringSystem() {
existing := strings.TrimSpace(request.GetStringSystem())
if existing == "" {
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
} else {
request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
}
} else {
systemContents := request.ParseSystem()
newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
newSystem.SetText(info.ChannelSetting.SystemPrompt)
if len(systemContents) == 0 {
request.System = []dto.ClaudeMediaMessage{newSystem}
} else {
request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
}
}
}
}
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)

View File

@@ -79,18 +79,34 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
req.Images = []string{req.Image}
}
if req.HasImage() {
action = constant.TaskActionGenerate
if info.ChannelType == constant.ChannelTypeVidu {
// vidu 增加 首尾帧生视频和参考图生视频
if len(req.Images) == 2 {
action = constant.TaskActionFirstTailGenerate
} else if len(req.Images) > 2 {
action = constant.TaskActionReferenceGenerate
}
}
}
storeTaskRequest(c, info, action, req)
return nil
}
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
hasPrompt, ok := requestObj.(HasPrompt)
if !ok {
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
}
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
return taskErr
}
action := constant.TaskActionTextGenerate
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
action = constant.TaskActionGenerate
}
storeTaskRequest(c, info, action, requestObj)
return nil
}
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
var req TaskSubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
}
return ValidateTaskRequestWithImage(c, info, req)
}

View File

@@ -90,41 +90,39 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if info.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
if ok {
containSystemPrompt := false
for _, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
containSystemPrompt = true
break
}
request := convertedRequest.(*dto.GeneralOpenAIRequest)
containSystemPrompt := false
for _, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
containSystemPrompt = true
break
}
if !containSystemPrompt {
// 如果没有系统提示,则添加系统提示
systemMessage := dto.Message{
Role: request.GetSystemRoleName(),
Content: info.ChannelSetting.SystemPrompt,
}
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
} else if info.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
// 如果有系统提示,且允许覆盖,则拼接到前面
for i, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
if message.IsStringContent() {
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
} else {
contents := message.ParseContent()
contents = append([]dto.MediaContent{
{
Type: dto.ContentTypeText,
Text: info.ChannelSetting.SystemPrompt,
},
}, contents...)
request.Messages[i].Content = contents
}
break
}
if !containSystemPrompt {
// 如果没有系统提示,则添加系统提示
systemMessage := dto.Message{
Role: request.GetSystemRoleName(),
Content: info.ChannelSetting.SystemPrompt,
}
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
} else if info.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
// 如果有系统提示,且允许覆盖,则拼接到前面
for i, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
if message.IsStringContent() {
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
} else {
contents := message.ParseContent()
contents = append([]dto.MediaContent{
{
Type: dto.ContentTypeText,
Text: info.ChannelSetting.SystemPrompt,
},
}, contents...)
request.Messages[i].Content = contents
}
break
}
}
}
@@ -278,13 +276,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
fileSearchTool.CallCount, dFileSearchQuota.String())
}
}
var dImageGenerationCallQuota decimal.Decimal
var imageGenerationCallPrice float64
if ctx.GetBool("image_generation_call") {
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
}
var quotaCalculateDecimal decimal.Decimal
@@ -340,8 +331,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
// 添加 audio input 独立计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
// 添加 image generation call 计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
@@ -440,10 +429,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
other["audio_input_token_count"] = audioTokens
other["audio_input_price"] = audioInputPrice
}
if !dImageGenerationCallQuota.IsZero() {
other["image_generation_call"] = true
other["image_generation_call_price"] = imageGenerationCallPrice
}
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,

View File

@@ -6,7 +6,6 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/gemini"
@@ -95,32 +94,6 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
adaptor.Init(info)
if info.ChannelSetting.SystemPrompt != "" {
if request.SystemInstructions == nil {
request.SystemInstructions = &dto.GeminiChatContent{
Parts: []dto.GeminiPart{
{Text: info.ChannelSetting.SystemPrompt},
},
}
} else if len(request.SystemInstructions.Parts) == 0 {
request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
} else if info.ChannelSetting.SystemPromptOverride {
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
merged := false
for i := range request.SystemInstructions.Parts {
if request.SystemInstructions.Parts[i].Text == "" {
continue
}
request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
merged = true
break
}
if !merged {
request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
}
}
}
// Clean up empty system instruction
if request.SystemInstructions != nil {
hasContent := false

View File

@@ -52,8 +52,6 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var cacheRatio float64
var imageRatio float64
var cacheCreationRatio float64
var audioRatio float64
var audioCompletionRatio float64
if !usePrice {
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
if meta.MaxTokens != 0 {
@@ -75,8 +73,6 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
@@ -94,8 +90,6 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
AudioRatio: audioRatio,
AudioCompletionRatio: audioCompletionRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}

View File

@@ -21,11 +21,7 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
case types.RelayFormatOpenAI:
request, err = GetAndValidateTextRequest(c, relayMode)
case types.RelayFormatGemini:
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
request, err = GetAndValidateGeminiEmbeddingRequest(c)
} else {
request, err = GetAndValidateGeminiRequest(c)
}
request, err = GetAndValidateGeminiRequest(c)
case types.RelayFormatClaude:
request, err = GetAndValidateClaudeRequest(c)
case types.RelayFormatOpenAIResponses:
@@ -292,6 +288,7 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA
}
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
@@ -307,12 +304,3 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
return request, nil
}
func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
request := &dto.GeminiEmbeddingRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
return request, nil
}

View File

@@ -31,6 +31,21 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
// OAuth2 Server endpoints
apiRouter.GET("/.well-known/jwks.json", controller.GetJWKS)
apiRouter.GET("/.well-known/openid-configuration", controller.OAuthOIDCConfiguration)
apiRouter.GET("/.well-known/oauth-authorization-server", controller.OAuthServerInfo)
apiRouter.POST("/oauth/token", middleware.CriticalRateLimit(), controller.OAuthTokenEndpoint)
apiRouter.GET("/oauth/authorize", controller.OAuthAuthorizeEndpoint)
apiRouter.POST("/oauth/introspect", middleware.AdminAuth(), controller.OAuthIntrospect)
apiRouter.POST("/oauth/revoke", middleware.CriticalRateLimit(), controller.OAuthRevoke)
apiRouter.GET("/oauth/userinfo", middleware.OAuthJWTAuth(), controller.OAuthUserInfo)
// OAuth2 管理API (前端使用)
apiRouter.GET("/oauth/jwks", controller.GetJWKS)
apiRouter.GET("/oauth/server-info", controller.OAuthServerInfo)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
@@ -40,6 +55,17 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
// OAuth2 admin operations
oauthAdmin := apiRouter.Group("/oauth")
oauthAdmin.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
{
oauthAdmin.POST("/keys/rotate", controller.RotateOAuthSigningKey)
oauthAdmin.GET("/keys", controller.ListOAuthSigningKeys)
oauthAdmin.DELETE("/keys/:kid", controller.DeleteOAuthSigningKey)
oauthAdmin.POST("/keys/generate_file", controller.GenerateOAuthSigningKeyFile)
oauthAdmin.POST("/keys/import_pem", controller.ImportOAuthSigningKey)
}
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
@@ -78,7 +104,7 @@ func SetApiRouter(router *gin.Engine) {
}
adminRoute := userRoute.Group("/")
adminRoute.Use(middleware.AdminAuth())
adminRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
adminRoute.GET("/", controller.GetAllUsers)
adminRoute.GET("/search", controller.SearchUsers)
@@ -94,7 +120,7 @@ func SetApiRouter(router *gin.Engine) {
}
}
optionRoute := apiRouter.Group("/option")
optionRoute.Use(middleware.RootAuth())
optionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
{
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)
@@ -108,7 +134,7 @@ func SetApiRouter(router *gin.Engine) {
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
channelRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
@@ -159,7 +185,7 @@ func SetApiRouter(router *gin.Engine) {
}
redemptionRoute := apiRouter.Group("/redemption")
redemptionRoute.Use(middleware.AdminAuth())
redemptionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
redemptionRoute.GET("/", controller.GetAllRedemptions)
redemptionRoute.GET("/search", controller.SearchRedemptions)
@@ -187,13 +213,13 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/token", controller.GetLogByKey)
}
groupRoute := apiRouter.Group("/group")
groupRoute.Use(middleware.AdminAuth())
groupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
groupRoute.GET("/", controller.GetGroups)
}
prefillGroupRoute := apiRouter.Group("/prefill_group")
prefillGroupRoute.Use(middleware.AdminAuth())
prefillGroupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
{
prefillGroupRoute.GET("/", controller.GetPrefillGroups)
prefillGroupRoute.POST("/", controller.CreatePrefillGroup)
@@ -235,5 +261,17 @@ func SetApiRouter(router *gin.Engine) {
modelsRoute.PUT("/", controller.UpdateModelMeta)
modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
}
oauthClientsRoute := apiRouter.Group("/oauth_clients")
oauthClientsRoute.Use(middleware.AdminAuth())
{
oauthClientsRoute.GET("/", controller.GetAllOAuthClients)
oauthClientsRoute.GET("/search", controller.SearchOAuthClients)
oauthClientsRoute.GET("/:id", controller.GetOAuthClient)
oauthClientsRoute.POST("/", controller.CreateOAuthClient)
oauthClientsRoute.PUT("/", controller.UpdateOAuthClient)
oauthClientsRoute.DELETE("/:id", controller.DeleteOAuthClient)
oauthClientsRoute.POST("/:id/regenerate_secret", controller.RegenerateOAuthClientSecret)
}
}
}

View File

@@ -28,12 +28,6 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
return nil, fmt.Errorf("only support https url")
}
// SSRF防护验证请求URL
fetchSetting := system_setting.GetFetchSetting()
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
return nil, fmt.Errorf("request reject: %v", err)
}
workerUrl := system_setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
@@ -57,13 +51,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
}
return DoWorkerRequest(req)
} else {
// SSRF防护验证请求URL非Worker模式
fetchSetting := system_setting.GetFetchSetting()
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
return nil, fmt.Errorf("request reject: %v", err)
}
common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
return http.Get(originUrl)
}
}

View File

@@ -7,17 +7,12 @@ import (
"net/http"
"net/url"
"one-api/common"
"sync"
"time"
"golang.org/x/net/proxy"
)
var (
httpClient *http.Client
proxyClientLock sync.Mutex
proxyClients = make(map[string]*http.Client)
)
var httpClient *http.Client
func InitHttpClient() {
if common.RelayTimeout == 0 {
@@ -33,31 +28,12 @@ func GetHttpClient() *http.Client {
return httpClient
}
// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
func ResetProxyClientCache() {
proxyClientLock.Lock()
defer proxyClientLock.Unlock()
for _, client := range proxyClients {
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
transport.CloseIdleConnections()
}
}
proxyClients = make(map[string]*http.Client)
}
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
if proxyURL == "" {
return http.DefaultClient, nil
}
proxyClientLock.Lock()
if client, ok := proxyClients[proxyURL]; ok {
proxyClientLock.Unlock()
return client, nil
}
proxyClientLock.Unlock()
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, err
@@ -65,16 +41,11 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
switch parsedURL.Scheme {
case "http", "https":
client := &http.Client{
return &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(parsedURL),
},
}
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
proxyClientLock.Lock()
proxyClients[proxyURL] = client
proxyClientLock.Unlock()
return client, nil
}, nil
case "socks5", "socks5h":
// 获取认证信息
@@ -96,18 +67,13 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
return nil, err
}
client := &http.Client{
return &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
},
}
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
proxyClientLock.Lock()
proxyClients[proxyURL] = client
proxyClientLock.Unlock()
return client, nil
}, nil
default:
return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)

View File

@@ -19,7 +19,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
gopool.Go(func() {
relayInfoCopy := *relayInfo
err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
if err != nil {
common.SysLog("error return pre-consumed quota: " + err.Error())
}

View File

@@ -113,12 +113,6 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
}
} else {
// SSRF防护验证Bark URL非Worker模式
fetchSetting := system_setting.GetFetchSetting()
if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
return fmt.Errorf("request reject: %v", err)
}
// 直接发送请求
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
if err != nil {

View File

@@ -8,7 +8,6 @@ import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/setting/system_setting"
"time"
@@ -87,12 +86,6 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
}
} else {
// SSRF防护验证Webhook URL非Worker模式
fetchSetting := system_setting.GetFetchSetting()
if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
return fmt.Errorf("request reject: %v", err)
}
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
if err != nil {
return fmt.Errorf("failed to create webhook request: %v", err)

View File

@@ -10,18 +10,6 @@ const (
FileSearchPrice = 2.5
)
const (
GPTImage1Low1024x1024 = 0.011
GPTImage1Low1024x1536 = 0.016
GPTImage1Low1536x1024 = 0.016
GPTImage1Medium1024x1024 = 0.042
GPTImage1Medium1024x1536 = 0.063
GPTImage1Medium1536x1024 = 0.063
GPTImage1High1024x1024 = 0.167
GPTImage1High1024x1536 = 0.25
GPTImage1High1536x1024 = 0.25
)
const (
// Gemini Audio Input Price
Gemini25FlashPreviewInputAudioPrice = 1.00
@@ -77,31 +65,3 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
}
return 0
}
func GetGPTImage1PriceOnceCall(quality string, size string) float64 {
prices := map[string]map[string]float64{
"low": {
"1024x1024": GPTImage1Low1024x1024,
"1024x1536": GPTImage1Low1024x1536,
"1536x1024": GPTImage1Low1536x1024,
},
"medium": {
"1024x1024": GPTImage1Medium1024x1024,
"1024x1536": GPTImage1Medium1024x1536,
"1536x1024": GPTImage1Medium1536x1024,
},
"high": {
"1024x1024": GPTImage1High1024x1024,
"1024x1536": GPTImage1High1024x1536,
"1536x1024": GPTImage1High1536x1024,
},
}
if qualityMap, exists := prices[quality]; exists {
if price, exists := qualityMap[size]; exists {
return price
}
}
return GPTImage1High1024x1024
}

View File

@@ -5,4 +5,3 @@ var StripeWebhookSecret = ""
var StripePriceId = ""
var StripeUnitPrice = 8.0
var StripeMinTopUp = 1
var StripePromotionCodesEnabled = false

View File

@@ -178,7 +178,6 @@ var defaultModelRatio = map[string]float64{
"gemini-2.5-flash-lite-preview-thinking-*": 0.05,
"gemini-2.5-flash-lite-preview-06-17": 0.05,
"gemini-2.5-flash": 0.15,
"gemini-embedding-001": 0.075,
"text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
@@ -279,18 +278,6 @@ var defaultModelPrice = map[string]float64{
"mj_upload": 0.05,
}
var defaultAudioRatio = map[string]float64{
"gpt-4o-audio-preview": 16,
"gpt-4o-mini-audio-preview": 66.67,
"gpt-4o-realtime-preview": 8,
"gpt-4o-mini-realtime-preview": 16.67,
}
var defaultAudioCompletionRatio = map[string]float64{
"gpt-4o-realtime": 2,
"gpt-4o-mini-realtime": 2,
}
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
@@ -339,15 +326,6 @@ func InitRatioSettings() {
imageRatioMap = defaultImageRatio
imageRatioMapMutex.Unlock()
// initialize audioRatioMap
audioRatioMapMutex.Lock()
audioRatioMap = defaultAudioRatio
audioRatioMapMutex.Unlock()
// initialize audioCompletionRatioMap
audioCompletionRatioMapMutex.Lock()
audioCompletionRatioMap = defaultAudioCompletionRatio
audioCompletionRatioMapMutex.Unlock()
}
func GetModelPriceMap() map[string]float64 {
@@ -439,18 +417,6 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
func GetDefaultImageRatioMap() map[string]float64 {
return defaultImageRatio
}
func GetDefaultAudioRatioMap() map[string]float64 {
return defaultAudioRatio
}
func GetDefaultAudioCompletionRatioMap() map[string]float64 {
return defaultAudioCompletionRatio
}
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
@@ -618,22 +584,32 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
}
func GetAudioRatio(name string) float64 {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if ratio, ok := audioRatioMap[name]; ok {
return ratio
if strings.Contains(name, "-realtime") {
if strings.HasSuffix(name, "gpt-4o-realtime-preview") {
return 8
} else if strings.Contains(name, "gpt-4o-mini-realtime-preview") {
return 10 / 0.6
} else {
return 20
}
}
if strings.Contains(name, "-audio") {
if strings.HasPrefix(name, "gpt-4o-audio-preview") {
return 40 / 2.5
} else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") {
return 10 / 0.15
} else {
return 40
}
}
return 20
}
func GetAudioCompletionRatio(name string) float64 {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if ratio, ok := audioCompletionRatioMap[name]; ok {
return ratio
if strings.HasPrefix(name, "gpt-4o-realtime") {
return 2
} else if strings.HasPrefix(name, "gpt-4o-mini-realtime") {
return 2
}
return 2
}
@@ -654,14 +630,6 @@ var defaultImageRatio = map[string]float64{
}
var imageRatioMap map[string]float64
var imageRatioMapMutex sync.RWMutex
var (
audioRatioMap map[string]float64 = nil
audioRatioMapMutex = sync.RWMutex{}
)
var (
audioCompletionRatioMap map[string]float64 = nil
audioCompletionRatioMapMutex = sync.RWMutex{}
)
func ImageRatio2JSONString() string {
imageRatioMapMutex.RLock()
@@ -690,71 +658,6 @@ func GetImageRatio(name string) (float64, bool) {
return ratio, true
}
func AudioRatio2JSONString() string {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(audioRatioMap)
if err != nil {
common.SysError("error marshalling audio ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateAudioRatioByJSONString(jsonStr string) error {
tmp := make(map[string]float64)
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
return err
}
audioRatioMapMutex.Lock()
audioRatioMap = tmp
audioRatioMapMutex.Unlock()
InvalidateExposedDataCache()
return nil
}
func GetAudioRatioCopy() map[string]float64 {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(audioRatioMap))
for k, v := range audioRatioMap {
copyMap[k] = v
}
return copyMap
}
func AudioCompletionRatio2JSONString() string {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(audioCompletionRatioMap)
if err != nil {
common.SysError("error marshalling audio completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
tmp := make(map[string]float64)
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
return err
}
audioCompletionRatioMapMutex.Lock()
audioCompletionRatioMap = tmp
audioCompletionRatioMapMutex.Unlock()
InvalidateExposedDataCache()
return nil
}
func GetAudioCompletionRatioCopy() map[string]float64 {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(audioCompletionRatioMap))
for k, v := range audioCompletionRatioMap {
copyMap[k] = v
}
return copyMap
}
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()

View File

@@ -1,34 +0,0 @@
package system_setting
import "one-api/setting/config"
type FetchSetting struct {
EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护
AllowPrivateIp bool `json:"allow_private_ip"`
DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式true: 白名单模式false: 黑名单模式
IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式true: 白名单模式false: 黑名单模式
DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com
IpList []string `json:"ip_list"` // CIDR format
AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000
ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤实验性
}
var defaultFetchSetting = FetchSetting{
EnableSSRFProtection: true, // 默认开启SSRF防护
AllowPrivateIp: false,
DomainFilterMode: false,
IpFilterMode: false,
DomainList: []string{},
IpList: []string{},
AllowedPorts: []string{"80", "443", "8080", "8443"},
ApplyIPFilterForDomain: false,
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting)
}
func GetFetchSetting() *FetchSetting {
return &defaultFetchSetting
}

View File

@@ -0,0 +1,74 @@
package system_setting
import "one-api/setting/config"
type OAuth2Settings struct {
Enabled bool `json:"enabled"`
Issuer string `json:"issuer"`
AccessTokenTTL int `json:"access_token_ttl"` // in minutes
RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes
AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token
RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow
JWTSigningAlgorithm string `json:"jwt_signing_algorithm"`
JWTKeyID string `json:"jwt_key_id"`
JWTPrivateKeyFile string `json:"jwt_private_key_file"`
AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login
DefaultUserRole int `json:"default_user_role"` // default role for auto-created users
DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users
ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping
MaxJWKSKeys int `json:"max_jwks_keys"` // maximum number of JWKS signing keys to retain
DefaultPrivateKeyPath string `json:"default_private_key_path"` // suggested private key file path
}
// 默认配置
var defaultOAuth2Settings = OAuth2Settings{
Enabled: false,
AccessTokenTTL: 10, // 10 minutes
RefreshTokenTTL: 720, // 12 hours
AllowedGrantTypes: []string{"client_credentials", "authorization_code", "refresh_token"},
RequirePKCE: true,
JWTSigningAlgorithm: "RS256",
JWTKeyID: "oauth2-key-1",
AutoCreateUser: false,
DefaultUserRole: 1, // common user
DefaultUserGroup: "default",
ScopeMappings: map[string][]string{
"api:read": {"read"},
"api:write": {"write"},
"admin": {"admin"},
},
MaxJWKSKeys: 3,
DefaultPrivateKeyPath: "/etc/new-api/oauth2-private.pem",
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("oauth2", &defaultOAuth2Settings)
}
func GetOAuth2Settings() *OAuth2Settings {
return &defaultOAuth2Settings
}
// UpdateOAuth2Settings 更新OAuth2配置
func UpdateOAuth2Settings(settings OAuth2Settings) {
defaultOAuth2Settings = settings
}
// ValidateGrantType 验证授权类型是否被允许
func (s *OAuth2Settings) ValidateGrantType(grantType string) bool {
for _, allowedType := range s.AllowedGrantTypes {
if allowedType == grantType {
return true
}
}
return false
}
// GetScopePermissions 获取scope对应的权限
func (s *OAuth2Settings) GetScopePermissions(scope string) []string {
if perms, exists := s.ScopeMappings[scope]; exists {
return perms
}
return []string{}
}

1069
src/oauth/server.go Normal file

File diff suppressed because it is too large Load Diff

82
src/oauth/store.go Normal file
View File

@@ -0,0 +1,82 @@
package oauth
import (
"one-api/common"
"sync"
"time"
)
// KVStore is a minimal TTL key-value abstraction used by OAuth flows.
type KVStore interface {
Set(key, value string, ttl time.Duration) error
Get(key string) (string, bool)
Del(key string) error
}
type redisStore struct{}
func (r *redisStore) Set(key, value string, ttl time.Duration) error {
return common.RedisSet(key, value, ttl)
}
func (r *redisStore) Get(key string) (string, bool) {
v, err := common.RedisGet(key)
if err != nil || v == "" {
return "", false
}
return v, true
}
func (r *redisStore) Del(key string) error {
return common.RedisDel(key)
}
type memEntry struct {
val string
exp int64 // unix seconds, 0 means no expiry
}
type memoryStore struct {
m sync.Map // key -> memEntry
}
func (m *memoryStore) Set(key, value string, ttl time.Duration) error {
var exp int64
if ttl > 0 {
exp = time.Now().Add(ttl).Unix()
}
m.m.Store(key, memEntry{val: value, exp: exp})
return nil
}
func (m *memoryStore) Get(key string) (string, bool) {
v, ok := m.m.Load(key)
if !ok {
return "", false
}
e := v.(memEntry)
if e.exp > 0 && time.Now().Unix() > e.exp {
m.m.Delete(key)
return "", false
}
return e.val, true
}
func (m *memoryStore) Del(key string) error {
m.m.Delete(key)
return nil
}
var (
memStore = &memoryStore{}
rdsStore = &redisStore{}
)
func getStore() KVStore {
if common.RedisEnabled {
return rdsStore
}
return memStore
}
func storeSet(key, val string, ttl time.Duration) error { return getStore().Set(key, val, ttl) }
func storeGet(key string) (string, bool) { return getStore().Get(key) }
func storeDel(key string) error { return getStore().Del(key) }

59
src/oauth/util.go Normal file
View File

@@ -0,0 +1,59 @@
package oauth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
)
// getFormOrBasicAuth extracts client_id/client_secret from Basic Auth first, then form
func getFormOrBasicAuth(c *gin.Context) (clientID, clientSecret string) {
id, secret, ok := c.Request.BasicAuth()
if ok {
return strings.TrimSpace(id), strings.TrimSpace(secret)
}
return strings.TrimSpace(c.PostForm("client_id")), strings.TrimSpace(c.PostForm("client_secret"))
}
// genCode generates URL-safe random string based on nBytes of entropy
func genCode(nBytes int) (string, error) {
b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// s256Base64URL computes base64url-encoded SHA256 digest
func s256Base64URL(verifier string) string {
sum := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sum[:])
}
// writeNoStore sets no-store cache headers for OAuth responses
func writeNoStore(c *gin.Context) {
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
}
// writeOAuthRedirectError builds an error redirect to redirect_uri as RFC6749
func writeOAuthRedirectError(c *gin.Context, redirectURI, errCode, description, state string) {
writeNoStore(c)
q := "error=" + url.QueryEscape(errCode)
if description != "" {
q += "&error_description=" + url.QueryEscape(description)
}
if state != "" {
q += "&state=" + url.QueryEscape(state)
}
sep := "?"
if strings.Contains(redirectURI, "?") {
sep = "&"
}
c.Redirect(http.StatusFound, redirectURI+sep+q)
}

View File

@@ -122,9 +122,6 @@ func (e *NewAPIError) MaskSensitiveError() string {
return string(e.errorCode)
}
errStr := e.Err.Error()
if e.errorCode == ErrorCodeCountTokenFailed {
return errStr
}
return common.MaskSensitiveInfo(errStr)
}
@@ -156,9 +153,8 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
Code: e.errorCode,
}
}
if e.errorCode != ErrorCodeCountTokenFailed {
result.Message = common.MaskSensitiveInfo(result.Message)
}
result.Message = common.MaskSensitiveInfo(result.Message)
return result
}
@@ -182,9 +178,7 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
Type: string(e.errorType),
}
}
if e.errorCode != ErrorCodeCountTokenFailed {
result.Message = common.MaskSensitiveInfo(result.Message)
}
result.Message = common.MaskSensitiveInfo(result.Message)
return result
}

View File

@@ -15,8 +15,6 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
AudioRatio float64
AudioCompletionRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
@@ -29,5 +27,5 @@ type PerCallPriceData struct {
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}

View File

@@ -10,7 +10,6 @@
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure可用于二次分发管理 key仅单可执行文件已打包好 Docker 镜像,一键部署,开箱即用"
/>
<title>New API</title>
<analytics></analytics>
</head>
<body>

View File

@@ -1,9 +0,0 @@
{
"compilerOptions": {
"baseUrl": "./",
"paths": {
"@/*": ["src/*"]
}
},
"include": ["src/**/*"]
}

662
web/public/oauth-demo.html Normal file
View File

@@ -0,0 +1,662 @@
<!-- This file is a copy of examples/oauth-demo.html for direct serving under /oauth-demo.html -->
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>OAuth2/OIDC 授权码 + PKCE 前端演示</title>
<style>
:root {
--bg: #0b0c10;
--panel: #111317;
--muted: #aab2bf;
--accent: #3b82f6;
--ok: #16a34a;
--warn: #f59e0b;
--err: #ef4444;
--border: #1f2430;
}
body {
margin: 0;
font-family:
ui-sans-serif,
system-ui,
-apple-system,
Segoe UI,
Roboto,
Helvetica,
Arial;
background: var(--bg);
color: #e5e7eb;
}
.wrap {
max-width: 980px;
margin: 32px auto;
padding: 0 16px;
}
h1 {
font-size: 22px;
margin: 0 0 16px;
}
.card {
background: var(--panel);
border: 1px solid var(--border);
border-radius: 10px;
padding: 16px;
margin: 12px 0;
}
.row {
display: flex;
gap: 12px;
flex-wrap: wrap;
}
.col {
flex: 1 1 280px;
display: flex;
flex-direction: column;
}
label {
font-size: 12px;
color: var(--muted);
margin-bottom: 6px;
}
input,
textarea,
select {
background: #0f1115;
color: #e5e7eb;
border: 1px solid var(--border);
padding: 10px 12px;
border-radius: 8px;
outline: none;
}
textarea {
min-height: 100px;
resize: vertical;
}
.btns {
display: flex;
gap: 8px;
flex-wrap: wrap;
margin-top: 8px;
}
button {
background: #1a1f2b;
color: #e5e7eb;
border: 1px solid var(--border);
padding: 8px 12px;
border-radius: 8px;
cursor: pointer;
}
button.primary {
background: var(--accent);
border-color: var(--accent);
color: white;
}
button.ok {
background: var(--ok);
border-color: var(--ok);
color: white;
}
button.warn {
background: var(--warn);
border-color: var(--warn);
color: black;
}
button.ghost {
background: transparent;
}
.muted {
color: var(--muted);
font-size: 12px;
}
.mono {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
'Liberation Mono', 'Courier New', monospace;
}
.grid2 {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 12px;
}
@media (max-width: 880px) {
.grid2 {
grid-template-columns: 1fr;
}
}
.ok {
color: #10b981;
}
.err {
color: #ef4444;
}
.sep {
height: 1px;
background: var(--border);
margin: 12px 0;
}
</style>
</head>
<body>
<div class="wrap">
<h1>OAuth2/OIDC 授权码 + PKCE 前端演示</h1>
<div class="card">
<div class="row">
<div class="col">
<label
>Issuer可选用于自动发现
/.well-known/openid-configuration</label
>
<input id="issuer" placeholder="https://your-domain" />
<div class="btns">
<button class="" id="btnDiscover">自动发现端点</button>
</div>
<div class="muted">提示:若未配置 Issuer可直接填写下方端点。</div>
</div>
</div>
<div class="row">
<div class="col">
<label>Response Type</label>
<select id="response_type">
<option value="code" selected>code</option>
<option value="token">token</option>
</select>
</div>
<div class="col">
<label>Authorization Endpoint</label
><input
id="authorization_endpoint"
placeholder="https://domain/api/oauth/authorize"
/>
</div>
<div class="col">
<label>Token Endpoint</label
><input
id="token_endpoint"
placeholder="https://domain/api/oauth/token"
/>
</div>
</div>
<div class="row">
<div class="col">
<label>UserInfo Endpoint可选</label
><input
id="userinfo_endpoint"
placeholder="https://domain/api/oauth/userinfo"
/>
</div>
<div class="col">
<label>Client ID</label
><input id="client_id" placeholder="your-public-client-id" />
</div>
</div>
<div class="row">
<div class="col">
<label>Client Secret可选机密客户端</label
><input id="client_secret" placeholder="留空表示公开客户端" />
</div>
</div>
<div class="row">
<div class="col">
<label>Redirect URI当前页地址或你的回调</label
><input id="redirect_uri" />
</div>
<div class="col">
<label>Scope</label
><input id="scope" value="openid profile email" />
</div>
</div>
<div class="row">
<div class="col"><label>State</label><input id="state" /></div>
<div class="col"><label>Nonce</label><input id="nonce" /></div>
</div>
<div class="row">
<div class="col">
<label>Code Verifier自动生成不会上送</label
><input id="code_verifier" class="mono" readonly />
</div>
<div class="col">
<label>Code ChallengeS256</label
><input id="code_challenge" class="mono" readonly />
</div>
</div>
<div class="btns">
<button id="btnGenPkce">生成 PKCE</button>
<button id="btnRandomState">随机 State</button>
<button id="btnRandomNonce">随机 Nonce</button>
<button id="btnMakeAuthURL">生成授权链接</button>
<button id="btnAuthorize" class="primary">跳转授权</button>
</div>
<div class="row" style="margin-top: 8px">
<div class="col">
<label>授权链接(只生成不跳转)</label>
<textarea
id="authorize_url"
class="mono"
placeholder="(空)"
></textarea>
<div class="btns">
<button id="btnCopyAuthURL">复制链接</button>
</div>
</div>
</div>
<div class="sep"></div>
<div class="muted">
说明:
<ul>
<li>
本页为纯前端演示,适用于公开客户端(不需要 client_secret
</li>
<li>
如跨域调用 Token/UserInfo需要服务端正确设置 CORS建议将此 demo
部署到同源域名下。
</li>
</ul>
</div>
<div class="sep"></div>
<div class="row">
<div class="col">
<label
>粘贴 OIDC Discovery
JSON/.well-known/openid-configuration</label
>
<textarea
id="conf_json"
class="mono"
placeholder='{"issuer":"https://...","authorization_endpoint":"...","token_endpoint":"...","userinfo_endpoint":"..."}'
></textarea>
<div class="btns">
<button id="btnParseConf">解析并填充端点</button>
<button id="btnGenConf">用当前端点生成 JSON</button>
</div>
<div class="muted">
可将服务端返回的 OIDC Discovery JSON
粘贴到此处,点击“解析并填充端点”。
</div>
</div>
</div>
</div>
<div class="card">
<div class="row">
<div class="col">
<label>授权结果</label>
<div id="authResult" class="muted">等待授权...</div>
</div>
</div>
<div class="grid2" style="margin-top: 12px">
<div>
<label>Access Token</label>
<textarea
id="access_token"
class="mono"
placeholder="(空)"
></textarea>
<div class="btns">
<button id="btnCopyAT">复制</button
><button id="btnCallUserInfo" class="ok">调用 UserInfo</button>
</div>
<div id="userinfoOut" class="muted" style="margin-top: 6px"></div>
</div>
<div>
<label>ID TokenJWT</label>
<textarea id="id_token" class="mono" placeholder="(空)"></textarea>
<div class="btns">
<button id="btnDecodeJWT">解码显示 Claims</button>
</div>
<pre
id="jwtClaims"
class="mono"
style="
white-space: pre-wrap;
word-break: break-all;
margin-top: 6px;
"
></pre>
</div>
</div>
<div class="grid2" style="margin-top: 12px">
<div>
<label>Refresh Token</label>
<textarea
id="refresh_token"
class="mono"
placeholder="(空)"
></textarea>
<div class="btns">
<button id="btnRefreshToken">使用 Refresh Token 刷新</button>
</div>
</div>
<div>
<label>原始 Token 响应</label>
<textarea id="token_raw" class="mono" placeholder="(空)"></textarea>
</div>
</div>
</div>
</div>
<script>
const $ = (id) => document.getElementById(id);
const toB64Url = (buf) =>
btoa(String.fromCharCode(...new Uint8Array(buf)))
.replace(/\+/g, '-')
.replace(/\//g, '_')
.replace(/=+$/, '');
async function sha256B64Url(str) {
const data = new TextEncoder().encode(str);
const digest = await crypto.subtle.digest('SHA-256', data);
return toB64Url(digest);
}
function randStr(len = 64) {
const chars =
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
const arr = new Uint8Array(len);
crypto.getRandomValues(arr);
return Array.from(arr, (v) => chars[v % chars.length]).join('');
}
function setAuthInfo(msg, ok = true) {
const el = $('authResult');
el.textContent = msg;
el.className = ok ? 'ok' : 'err';
}
function qs(name) {
const u = new URL(location.href);
return u.searchParams.get(name);
}
function persist(k, v) {
sessionStorage.setItem('demo_' + k, v);
}
function load(k) {
return sessionStorage.getItem('demo_' + k) || '';
}
(function init() {
$('redirect_uri').value =
window.location.origin + window.location.pathname;
const iss = load('issuer');
if (iss) $('issuer').value = iss;
const cid = load('client_id');
if (cid) $('client_id').value = cid;
const scp = load('scope');
if (scp) $('scope').value = scp;
})();
$('btnDiscover').onclick = async () => {
const iss = $('issuer').value.trim();
if (!iss) {
alert('请填写 Issuer');
return;
}
try {
persist('issuer', iss);
const res = await fetch(
iss.replace(/\/$/, '') + '/api/.well-known/openid-configuration',
);
const d = await res.json();
$('authorization_endpoint').value = d.authorization_endpoint || '';
$('token_endpoint').value = d.token_endpoint || '';
$('userinfo_endpoint').value = d.userinfo_endpoint || '';
if (d.issuer) {
$('issuer').value = d.issuer;
persist('issuer', d.issuer);
}
$('conf_json').value = JSON.stringify(d, null, 2);
setAuthInfo('已从发现文档加载端点', true);
} catch (e) {
setAuthInfo('自动发现失败:' + e, false);
}
};
$('btnGenPkce').onclick = async () => {
const v = randStr(64);
const c = await sha256B64Url(v);
$('code_verifier').value = v;
$('code_challenge').value = c;
persist('code_verifier', v);
persist('code_challenge', c);
setAuthInfo('已生成 PKCE 参数', true);
};
$('btnRandomState').onclick = () => {
$('state').value = randStr(16);
persist('state', $('state').value);
};
$('btnRandomNonce').onclick = () => {
$('nonce').value = randStr(16);
persist('nonce', $('nonce').value);
};
function buildAuthorizeURLFromFields() {
const auth = $('authorization_endpoint').value.trim();
const token = $('token_endpoint').value.trim();
const cid = $('client_id').value.trim();
const red = $('redirect_uri').value.trim();
const scp = $('scope').value.trim() || 'openid profile email';
const rt = $('response_type').value;
const st = $('state').value.trim() || randStr(16);
const no = $('nonce').value.trim() || randStr(16);
const cc = $('code_challenge').value.trim();
const cv = $('code_verifier').value.trim();
if (!auth || !cid || !red) {
throw new Error('请先完善端点/ClientID/RedirectURI');
}
if (rt === 'code' && (!cc || !cv)) {
throw new Error('请先生成 PKCE');
}
persist('authorization_endpoint', auth);
persist('token_endpoint', token);
persist('client_id', cid);
persist('redirect_uri', red);
persist('scope', scp);
persist('state', st);
persist('nonce', no);
persist('code_verifier', cv);
const u = new URL(auth);
u.searchParams.set('response_type', rt);
u.searchParams.set('client_id', cid);
u.searchParams.set('redirect_uri', red);
u.searchParams.set('scope', scp);
u.searchParams.set('state', st);
if (no) u.searchParams.set('nonce', no);
if (rt === 'code') {
u.searchParams.set('code_challenge', cc);
u.searchParams.set('code_challenge_method', 'S256');
}
return u.toString();
}
$('btnMakeAuthURL').onclick = () => {
try {
const url = buildAuthorizeURLFromFields();
$('authorize_url').value = url;
setAuthInfo('已生成授权链接', true);
} catch (e) {
setAuthInfo(e.message, false);
}
};
$('btnAuthorize').onclick = () => {
try {
const url = buildAuthorizeURLFromFields();
location.href = url;
} catch (e) {
setAuthInfo(e.message, false);
}
};
$('btnCopyAuthURL').onclick = async () => {
try {
await navigator.clipboard.writeText($('authorize_url').value);
} catch {}
};
async function postForm(url, data, basic) {
const body = Object.entries(data)
.map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`)
.join('&');
const headers = { 'Content-Type': 'application/x-www-form-urlencoded' };
if (basic && basic.id && basic.secret) {
headers['Authorization'] =
'Basic ' + btoa(`${basic.id}:${basic.secret}`);
}
const res = await fetch(url, { method: 'POST', headers, body });
if (!res.ok) {
const t = await res.text();
throw new Error(`HTTP ${res.status} ${t}`);
}
return res.json();
}
async function handleCallback() {
const frag =
location.hash && location.hash.startsWith('#')
? new URLSearchParams(location.hash.slice(1))
: null;
const at = frag ? frag.get('access_token') : null;
const err = qs('error') || (frag ? frag.get('error') : null);
const state = qs('state') || (frag ? frag.get('state') : null);
if (err) {
setAuthInfo('授权失败:' + err, false);
return;
}
if (at) {
$('access_token').value = at || '';
$('token_raw').value = JSON.stringify(
{
access_token: at,
token_type: frag.get('token_type'),
expires_in: frag.get('expires_in'),
scope: frag.get('scope'),
state,
},
null,
2,
);
setAuthInfo('隐式模式已获取 Access Token', true);
return;
}
const code = qs('code');
if (!code) {
setAuthInfo('等待授权...', true);
return;
}
if (state && load('state') && state !== load('state')) {
setAuthInfo('state 不匹配,已拒绝', false);
return;
}
try {
const tokenEp = load('token_endpoint');
const cid = load('client_id');
const csec = $('client_secret').value.trim();
const basic = csec ? { id: cid, secret: csec } : null;
const data = await postForm(
tokenEp,
{
grant_type: 'authorization_code',
code,
client_id: cid,
redirect_uri: load('redirect_uri'),
code_verifier: load('code_verifier'),
},
basic,
);
$('access_token').value = data.access_token || '';
$('id_token').value = data.id_token || '';
$('refresh_token').value = data.refresh_token || '';
$('token_raw').value = JSON.stringify(data, null, 2);
setAuthInfo('授权成功,已获取令牌', true);
} catch (e) {
setAuthInfo('交换令牌失败:' + e.message, false);
}
}
handleCallback();
$('btnCopyAT').onclick = async () => {
try {
await navigator.clipboard.writeText($('access_token').value);
} catch {}
};
$('btnDecodeJWT').onclick = () => {
const t = $('id_token').value.trim();
if (!t) {
$('jwtClaims').textContent = '(空)';
return;
}
const parts = t.split('.');
if (parts.length < 2) {
$('jwtClaims').textContent = '格式错误';
return;
}
try {
const json = JSON.parse(
atob(parts[1].replace(/-/g, '+').replace(/_/g, '/')),
);
$('jwtClaims').textContent = JSON.stringify(json, null, 2);
} catch (e) {
$('jwtClaims').textContent = '解码失败:' + e;
}
};
$('btnCallUserInfo').onclick = async () => {
const at = $('access_token').value.trim();
const ep = $('userinfo_endpoint').value.trim();
if (!at || !ep) {
alert('请填写UserInfo端点并获取AccessToken');
return;
}
try {
const res = await fetch(ep, {
headers: { Authorization: 'Bearer ' + at },
});
const data = await res.json();
$('userinfoOut').textContent = JSON.stringify(data, null, 2);
} catch (e) {
$('userinfoOut').textContent = '调用失败:' + e;
}
};
$('btnRefreshToken').onclick = async () => {
const rt = $('refresh_token').value.trim();
if (!rt) {
alert('没有刷新令牌');
return;
}
try {
const tokenEp = load('token_endpoint');
const cid = load('client_id');
const csec = $('client_secret').value.trim();
const basic = csec ? { id: cid, secret: csec } : null;
const data = await postForm(
tokenEp,
{ grant_type: 'refresh_token', refresh_token: rt, client_id: cid },
basic,
);
$('access_token').value = data.access_token || '';
$('id_token').value = data.id_token || '';
$('refresh_token').value = data.refresh_token || '';
$('token_raw').value = JSON.stringify(data, null, 2);
setAuthInfo('刷新成功', true);
} catch (e) {
setAuthInfo('刷新失败:' + e.message, false);
}
};
$('btnParseConf').onclick = () => {
const txt = $('conf_json').value.trim();
if (!txt) {
alert('请先粘贴 JSON');
return;
}
try {
const d = JSON.parse(txt);
if (d.issuer) {
$('issuer').value = d.issuer;
persist('issuer', d.issuer);
}
if (d.authorization_endpoint)
$('authorization_endpoint').value = d.authorization_endpoint;
if (d.token_endpoint) $('token_endpoint').value = d.token_endpoint;
if (d.userinfo_endpoint)
$('userinfo_endpoint').value = d.userinfo_endpoint;
setAuthInfo('已解析配置并填充端点', true);
} catch (e) {
setAuthInfo('解析失败:' + e, false);
}
};
$('btnGenConf').onclick = () => {
const d = {
issuer: $('issuer').value.trim() || undefined,
authorization_endpoint:
$('authorization_endpoint').value.trim() || undefined,
token_endpoint: $('token_endpoint').value.trim() || undefined,
userinfo_endpoint: $('userinfo_endpoint').value.trim() || undefined,
};
$('conf_json').value = JSON.stringify(d, null, 2);
};
</script>
</body>
</html>

View File

@@ -44,6 +44,7 @@ import Task from './pages/Task';
import ModelPage from './pages/Model';
import Playground from './pages/Playground';
import OAuth2Callback from './components/auth/OAuth2Callback';
import OAuthConsent from './pages/OAuth';
import PersonalSetting from './components/settings/PersonalSetting';
import Setup from './pages/Setup';
import SetupCheck from './components/layout/SetupCheck';
@@ -198,6 +199,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/oauth/consent'
element={
<Suspense fallback={<Loading></Loading>}>
<OAuthConsent />
</Suspense>
}
/>
<Route
path='/oauth/linuxdo'
element={

View File

@@ -176,7 +176,11 @@ const LoginForm = () => {
centered: true,
});
}
navigate('/console');
// 优先跳回 next仅允许相对路径
const sp = new URLSearchParams(window.location.search);
const next = sp.get('next');
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
navigate(isSafeInternalPath ? next : '/console');
} else {
showError(message);
}
@@ -286,7 +290,10 @@ const LoginForm = () => {
setUserData(data);
updateAPI();
showSuccess('登录成功!');
navigate('/console');
const sp = new URLSearchParams(window.location.search);
const next = sp.get('next');
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
navigate(isSafeInternalPath ? next : '/console');
};
// 返回登录页面

View File

@@ -181,8 +181,8 @@ export function PreCode(props) {
e.preventDefault();
e.stopPropagation();
if (ref.current) {
const codeElement = ref.current.querySelector('code');
const code = codeElement?.textContent ?? '';
const code =
ref.current.querySelector('code')?.innerText ?? '';
copy(code).then((success) => {
if (success) {
Toast.success(t('代码已复制到剪贴板'));

View File

@@ -135,7 +135,9 @@ const TwoFactorAuthModal = ({
autoFocus
/>
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
{t('支持6位TOTP验证码或8位备用码可到`个人设置-安全设置-两步验证设置`配置或查看。')}
{t(
'支持6位TOTP验证码或8位备用码可到`个人设置-安全设置-两步验证设置`配置或查看。',
)}
</Typography.Text>
</div>
</div>

View File

@@ -21,7 +21,7 @@ import React, { useState, useMemo, useCallback } from 'react';
import { Button, Tooltip, Toast } from '@douyinfe/semi-ui';
import { Copy, ChevronDown, ChevronUp } from 'lucide-react';
import { useTranslation } from 'react-i18next';
import { copy } from '../../helpers';
import { copy } from '../../../helpers';
const PERFORMANCE_CONFIG = {
MAX_DISPLAY_LENGTH: 50000, //

View File

@@ -0,0 +1,135 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Modal, Typography } from '@douyinfe/semi-ui';
import PropTypes from 'prop-types';
import { useIsMobile } from '../../../hooks/common/useIsMobile';
const { Title } = Typography;
/**
* ResponsiveModal 响应式模态框组件
*
* 特性:
* - 响应式布局:移动端和桌面端不同的宽度和布局
* - 自定义头部:标题左对齐,操作按钮右对齐,移动端自动换行
* - Tailwind CSS 样式支持
* - 保持原 Modal 组件的所有功能
*/
const ResponsiveModal = ({
visible,
onCancel,
title,
headerActions = [],
children,
width = { mobile: '95%', desktop: 600 },
className = '',
footer = null,
titleProps = {},
headerClassName = '',
actionsClassName = '',
...props
}) => {
const isMobile = useIsMobile();
// 自定义 Header 组件
const CustomHeader = () => {
if (!title && (!headerActions || headerActions.length === 0)) return null;
return (
<div
className={`flex w-full gap-3 justify-between ${
isMobile ? 'flex-col items-start' : 'flex-row items-center'
} ${headerClassName}`}
>
{title && (
<Title heading={5} className='m-0 min-w-fit' {...titleProps}>
{title}
</Title>
)}
{headerActions && headerActions.length > 0 && (
<div
className={`flex flex-wrap gap-2 items-center ${
isMobile ? 'w-full justify-start' : 'w-auto justify-end'
} ${actionsClassName}`}
>
{headerActions.map((action, index) => (
<React.Fragment key={index}>{action}</React.Fragment>
))}
</div>
)}
</div>
);
};
// 计算模态框宽度
const getModalWidth = () => {
if (typeof width === 'object') {
return isMobile ? width.mobile : width.desktop;
}
return width;
};
return (
<Modal
visible={visible}
title={<CustomHeader />}
onCancel={onCancel}
footer={footer}
width={getModalWidth()}
className={`!top-12 ${className}`}
{...props}
>
{children}
</Modal>
);
};
ResponsiveModal.propTypes = {
// Modal 基础属性
visible: PropTypes.bool.isRequired,
onCancel: PropTypes.func.isRequired,
children: PropTypes.node,
// 自定义头部
title: PropTypes.oneOfType([PropTypes.string, PropTypes.node]),
headerActions: PropTypes.arrayOf(PropTypes.node),
// 样式和布局
width: PropTypes.oneOfType([
PropTypes.number,
PropTypes.string,
PropTypes.shape({
mobile: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
desktop: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
}),
]),
className: PropTypes.string,
footer: PropTypes.node,
// 标题自定义属性
titleProps: PropTypes.object,
// 自定义 CSS 类
headerClassName: PropTypes.string,
actionsClassName: PropTypes.string,
};
export default ResponsiveModal;

View File

@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useRef } from 'react';
import React from 'react';
import { Link } from 'react-router-dom';
import { Avatar, Button, Dropdown, Typography } from '@douyinfe/semi-ui';
import { ChevronDown } from 'lucide-react';
@@ -39,7 +39,6 @@ const UserArea = ({
navigate,
t,
}) => {
const dropdownRef = useRef(null);
if (isLoading) {
return (
<SkeletonWrapper
@@ -53,93 +52,90 @@ const UserArea = ({
if (userState.user) {
return (
<div className='relative' ref={dropdownRef}>
<Dropdown
position='bottomRight'
getPopupContainer={() => dropdownRef.current}
render={
<Dropdown.Menu className='!bg-semi-color-bg-overlay !border-semi-color-border !shadow-lg !rounded-lg dark:!bg-gray-700 dark:!border-gray-600'>
<Dropdown.Item
onClick={() => {
navigate('/console/personal');
}}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconUserSetting
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('个人设置')}</span>
</div>
</Dropdown.Item>
<Dropdown.Item
onClick={() => {
navigate('/console/token');
}}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconKey
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('令牌管理')}</span>
</div>
</Dropdown.Item>
<Dropdown.Item
onClick={() => {
navigate('/console/topup');
}}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconCreditCard
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('钱包管理')}</span>
</div>
</Dropdown.Item>
<Dropdown.Item
onClick={logout}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-red-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconExit
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('退出')}</span>
</div>
</Dropdown.Item>
</Dropdown.Menu>
}
>
<Button
theme='borderless'
type='tertiary'
className='flex items-center gap-1.5 !p-1 !rounded-full hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-700 !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2'
>
<Avatar
size='extra-small'
color={stringToColor(userState.user.username)}
className='mr-1'
<Dropdown
position='bottomRight'
render={
<Dropdown.Menu className='!bg-semi-color-bg-overlay !border-semi-color-border !shadow-lg !rounded-lg dark:!bg-gray-700 dark:!border-gray-600'>
<Dropdown.Item
onClick={() => {
navigate('/console/personal');
}}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
>
{userState.user.username[0].toUpperCase()}
</Avatar>
<span className='hidden md:inline'>
<Typography.Text className='!text-xs !font-medium !text-semi-color-text-1 dark:!text-gray-300 mr-1'>
{userState.user.username}
</Typography.Text>
</span>
<ChevronDown
size={14}
className='text-xs text-semi-color-text-2 dark:text-gray-400'
/>
</Button>
</Dropdown>
</div>
<div className='flex items-center gap-2'>
<IconUserSetting
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('个人设置')}</span>
</div>
</Dropdown.Item>
<Dropdown.Item
onClick={() => {
navigate('/console/token');
}}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconKey
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('令牌管理')}</span>
</div>
</Dropdown.Item>
<Dropdown.Item
onClick={() => {
navigate('/console/topup');
}}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconCreditCard
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('钱包管理')}</span>
</div>
</Dropdown.Item>
<Dropdown.Item
onClick={logout}
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-red-500 dark:hover:!text-white'
>
<div className='flex items-center gap-2'>
<IconExit
size='small'
className='text-gray-500 dark:text-gray-400'
/>
<span>{t('退出')}</span>
</div>
</Dropdown.Item>
</Dropdown.Menu>
}
>
<Button
theme='borderless'
type='tertiary'
className='flex items-center gap-1.5 !p-1 !rounded-full hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-700 !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2'
>
<Avatar
size='extra-small'
color={stringToColor(userState.user.username)}
className='mr-1'
>
{userState.user.username[0].toUpperCase()}
</Avatar>
<span className='hidden md:inline'>
<Typography.Text className='!text-xs !font-medium !text-semi-color-text-1 dark:!text-gray-300 mr-1'>
{userState.user.username}
</Typography.Text>
</span>
<ChevronDown
size={14}
className='text-xs text-semi-color-text-2 dark:text-gray-400'
/>
</Button>
</Dropdown>
);
} else {
const showRegisterButton = !isSelfUseMode;

View File

@@ -28,7 +28,7 @@ import {
} from '@douyinfe/semi-ui';
import { Code, Zap, Clock, X, Eye, Send } from 'lucide-react';
import { useTranslation } from 'react-i18next';
import CodeViewer from './CodeViewer';
import CodeViewer from '../common/ui/CodeViewer';
const DebugPanel = ({
debugData,

View File

@@ -0,0 +1,72 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState } from 'react';
import { Spin } from '@douyinfe/semi-ui';
import { API, showError } from '../../helpers';
import { useTranslation } from 'react-i18next';
import OAuth2ServerSettings from './oauth2/OAuth2ServerSettings';
import OAuth2ClientSettings from './oauth2/OAuth2ClientSettings';
const OAuth2Setting = () => {
const { t } = useTranslation();
const [options, setOptions] = useState({});
const [loading, setLoading] = useState(false);
const getOptions = async () => {
setLoading(true);
try {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
const map = {};
for (const item of data) {
map[item.key] = item.value;
}
setOptions(map);
} else {
showError(message);
}
} catch (error) {
showError(t('获取OAuth2设置失败'));
} finally {
setLoading(false);
}
};
const refresh = () => {
getOptions();
};
useEffect(() => {
getOptions();
}, []);
return (
<Spin spinning={loading} size='large'>
{/* 服务器配置 */}
<OAuth2ServerSettings options={options} refresh={refresh} />
{/* 客户端管理 */}
<OAuth2ClientSettings />
</Spin>
);
};
export default OAuth2Setting;

View File

@@ -45,7 +45,6 @@ const PaymentSetting = () => {
StripePriceId: '',
StripeUnitPrice: 8.0,
StripeMinTopUp: 1,
StripePromotionCodesEnabled: false,
});
let [loading, setLoading] = useState(false);

View File

@@ -19,14 +19,7 @@ For commercial licensing, please contact support@quantumnous.com
import React, { useContext, useEffect, useState } from 'react';
import { useNavigate } from 'react-router-dom';
import {
API,
copy,
showError,
showInfo,
showSuccess,
setStatusData,
} from '../../helpers';
import { API, copy, showError, showInfo, showSuccess } from '../../helpers';
import { UserContext } from '../../context/User';
import { Modal } from '@douyinfe/semi-ui';
import { useTranslation } from 'react-i18next';
@@ -78,40 +71,18 @@ const PersonalSetting = () => {
});
useEffect(() => {
let saved = localStorage.getItem('status');
if (saved) {
const parsed = JSON.parse(saved);
setStatus(parsed);
if (parsed.turnstile_check) {
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
setStatus(status);
if (status.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(parsed.turnstile_site_key);
} else {
setTurnstileEnabled(false);
setTurnstileSiteKey('');
setTurnstileSiteKey(status.turnstile_site_key);
}
}
// Always refresh status from server to avoid stale flags (e.g., admin just enabled OAuth)
(async () => {
try {
const res = await API.get('/api/status');
const { success, data } = res.data;
if (success && data) {
setStatus(data);
setStatusData(data);
if (data.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(data.turnstile_site_key);
} else {
setTurnstileEnabled(false);
setTurnstileSiteKey('');
}
}
} catch (e) {
// ignore and keep local status
}
})();
getUserData();
getUserData().then((res) => {
console.log(userState);
});
}, []);
useEffect(() => {

View File

@@ -39,9 +39,6 @@ const RatioSetting = () => {
CompletionRatio: '',
GroupRatio: '',
GroupGroupRatio: '',
ImageRatio: '',
AudioRatio: '',
AudioCompletionRatio: '',
AutoGroups: '',
DefaultUseAutoGroup: false,
ExposeRatioEnabled: false,
@@ -64,10 +61,7 @@ const RatioSetting = () => {
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' ||
item.key === 'CacheRatio' ||
item.key === 'ImageRatio' ||
item.key === 'AudioRatio' ||
item.key === 'AudioCompletionRatio'
item.key === 'CacheRatio'
) {
try {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);

View File

@@ -29,7 +29,6 @@ import {
TagInput,
Spin,
Card,
Radio,
} from '@douyinfe/semi-ui';
const { Text } = Typography;
import {
@@ -45,7 +44,6 @@ import { useTranslation } from 'react-i18next';
const SystemSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
PasswordLoginEnabled: '',
PasswordRegisterEnabled: '',
EmailVerificationEnabled: '',
@@ -89,15 +87,6 @@ const SystemSetting = () => {
LinuxDOClientSecret: '',
LinuxDOMinimumTrustLevel: '',
ServerAddress: '',
// SSRF防护配置
'fetch_setting.enable_ssrf_protection': true,
'fetch_setting.allow_private_ip': '',
'fetch_setting.domain_filter_mode': false, // true 白名单false 黑名单
'fetch_setting.ip_filter_mode': false, // true 白名单false 黑名单
'fetch_setting.domain_list': [],
'fetch_setting.ip_list': [],
'fetch_setting.allowed_ports': [],
'fetch_setting.apply_ip_filter_for_domain': false,
});
const [originInputs, setOriginInputs] = useState({});
@@ -109,11 +98,6 @@ const SystemSetting = () => {
useState(false);
const [linuxDOOAuthEnabled, setLinuxDOOAuthEnabled] = useState(false);
const [emailToAdd, setEmailToAdd] = useState('');
const [domainFilterMode, setDomainFilterMode] = useState(true);
const [ipFilterMode, setIpFilterMode] = useState(true);
const [domainList, setDomainList] = useState([]);
const [ipList, setIpList] = useState([]);
const [allowedPorts, setAllowedPorts] = useState([]);
const getOptions = async () => {
setLoading(true);
@@ -129,37 +113,6 @@ const SystemSetting = () => {
case 'EmailDomainWhitelist':
setEmailDomainWhitelist(item.value ? item.value.split(',') : []);
break;
case 'fetch_setting.allow_private_ip':
case 'fetch_setting.enable_ssrf_protection':
case 'fetch_setting.domain_filter_mode':
case 'fetch_setting.ip_filter_mode':
case 'fetch_setting.apply_ip_filter_for_domain':
item.value = toBoolean(item.value);
break;
case 'fetch_setting.domain_list':
try {
const domains = item.value ? JSON.parse(item.value) : [];
setDomainList(Array.isArray(domains) ? domains : []);
} catch (e) {
setDomainList([]);
}
break;
case 'fetch_setting.ip_list':
try {
const ips = item.value ? JSON.parse(item.value) : [];
setIpList(Array.isArray(ips) ? ips : []);
} catch (e) {
setIpList([]);
}
break;
case 'fetch_setting.allowed_ports':
try {
const ports = item.value ? JSON.parse(item.value) : [];
setAllowedPorts(Array.isArray(ports) ? ports : []);
} catch (e) {
setAllowedPorts(['80', '443', '8080', '8443']);
}
break;
case 'PasswordLoginEnabled':
case 'PasswordRegisterEnabled':
case 'EmailVerificationEnabled':
@@ -187,13 +140,6 @@ const SystemSetting = () => {
});
setInputs(newInputs);
setOriginInputs(newInputs);
// 同步模式布尔到本地状态
if (typeof newInputs['fetch_setting.domain_filter_mode'] !== 'undefined') {
setDomainFilterMode(!!newInputs['fetch_setting.domain_filter_mode']);
}
if (typeof newInputs['fetch_setting.ip_filter_mode'] !== 'undefined') {
setIpFilterMode(!!newInputs['fetch_setting.ip_filter_mode']);
}
if (formApiRef.current) {
formApiRef.current.setValues(newInputs);
}
@@ -330,46 +276,6 @@ const SystemSetting = () => {
}
};
const submitSSRF = async () => {
const options = [];
// 处理域名过滤模式与列表
options.push({
key: 'fetch_setting.domain_filter_mode',
value: domainFilterMode,
});
if (Array.isArray(domainList)) {
options.push({
key: 'fetch_setting.domain_list',
value: JSON.stringify(domainList),
});
}
// 处理IP过滤模式与列表
options.push({
key: 'fetch_setting.ip_filter_mode',
value: ipFilterMode,
});
if (Array.isArray(ipList)) {
options.push({
key: 'fetch_setting.ip_list',
value: JSON.stringify(ipList),
});
}
// 处理端口配置
if (Array.isArray(allowedPorts)) {
options.push({
key: 'fetch_setting.allowed_ports',
value: JSON.stringify(allowedPorts),
});
}
if (options.length > 0) {
await updateOptions(options);
}
};
const handleAddEmail = () => {
if (emailToAdd && emailToAdd.trim() !== '') {
const domain = emailToAdd.trim();
@@ -681,179 +587,6 @@ const SystemSetting = () => {
</Form.Section>
</Card>
<Card>
<Form.Section text={t('SSRF防护设置')}>
<Text extraText={t('SSRF防护详细说明')}>
{t('配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全')}
</Text>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
>
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
<Form.Checkbox
field='fetch_setting.enable_ssrf_protection'
noLabel
extraText={t('SSRF防护开关详细说明')}
onChange={(e) =>
handleCheckboxChange('fetch_setting.enable_ssrf_protection', e)
}
>
{t('启用SSRF防护推荐开启以保护服务器安全')}
</Form.Checkbox>
</Col>
</Row>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
style={{ marginTop: 16 }}
>
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
<Form.Checkbox
field='fetch_setting.allow_private_ip'
noLabel
extraText={t('私有IP访问详细说明')}
onChange={(e) =>
handleCheckboxChange('fetch_setting.allow_private_ip', e)
}
>
{t('允许访问私有IP地址127.0.0.1、192.168.x.x等内网地址')}
</Form.Checkbox>
</Col>
</Row>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
style={{ marginTop: 16 }}
>
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
<Form.Checkbox
field='fetch_setting.apply_ip_filter_for_domain'
noLabel
extraText={t('域名IP过滤详细说明')}
onChange={(e) =>
handleCheckboxChange('fetch_setting.apply_ip_filter_for_domain', e)
}
style={{ marginBottom: 8 }}
>
{t('对域名启用 IP 过滤(实验性)')}
</Form.Checkbox>
<Text strong>
{t(domainFilterMode ? '域名白名单' : '域名黑名单')}
</Text>
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
{t('支持通配符格式example.com, *.api.example.com')}
</Text>
<Radio.Group
type='button'
value={domainFilterMode ? 'whitelist' : 'blacklist'}
onChange={(val) => {
const selected = val && val.target ? val.target.value : val;
const isWhitelist = selected === 'whitelist';
setDomainFilterMode(isWhitelist);
setInputs(prev => ({
...prev,
'fetch_setting.domain_filter_mode': isWhitelist,
}));
}}
style={{ marginBottom: 8 }}
>
<Radio value='whitelist'>{t('白名单')}</Radio>
<Radio value='blacklist'>{t('黑名单')}</Radio>
</Radio.Group>
<TagInput
value={domainList}
onChange={(value) => {
setDomainList(value);
// 触发Form的onChange事件
setInputs(prev => ({
...prev,
'fetch_setting.domain_list': value
}));
}}
placeholder={t('输入域名后回车example.com')}
style={{ width: '100%' }}
/>
</Col>
</Row>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
style={{ marginTop: 16 }}
>
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
<Text strong>
{t(ipFilterMode ? 'IP白名单' : 'IP黑名单')}
</Text>
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
{t('支持CIDR格式8.8.8.8, 192.168.1.0/24')}
</Text>
<Radio.Group
type='button'
value={ipFilterMode ? 'whitelist' : 'blacklist'}
onChange={(val) => {
const selected = val && val.target ? val.target.value : val;
const isWhitelist = selected === 'whitelist';
setIpFilterMode(isWhitelist);
setInputs(prev => ({
...prev,
'fetch_setting.ip_filter_mode': isWhitelist,
}));
}}
style={{ marginBottom: 8 }}
>
<Radio value='whitelist'>{t('白名单')}</Radio>
<Radio value='blacklist'>{t('黑名单')}</Radio>
</Radio.Group>
<TagInput
value={ipList}
onChange={(value) => {
setIpList(value);
// 触发Form的onChange事件
setInputs(prev => ({
...prev,
'fetch_setting.ip_list': value
}));
}}
placeholder={t('输入IP地址后回车8.8.8.8')}
style={{ width: '100%' }}
/>
</Col>
</Row>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
style={{ marginTop: 16 }}
>
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
<Text strong>{t('允许的端口')}</Text>
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
{t('支持单个端口和端口范围80, 443, 8000-8999')}
</Text>
<TagInput
value={allowedPorts}
onChange={(value) => {
setAllowedPorts(value);
// 触发Form的onChange事件
setInputs(prev => ({
...prev,
'fetch_setting.allowed_ports': value
}));
}}
placeholder={t('输入端口后回车80 或 8000-8999')}
style={{ width: '100%' }}
/>
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
{t('端口配置详细说明')}
</Text>
</Col>
</Row>
<Button onClick={submitSSRF} style={{ marginTop: 16 }}>
{t('更新SSRF防护设置')}
</Button>
</Form.Section>
</Card>
<Card>
<Form.Section text={t('配置登录注册')}>
<Row

View File

@@ -0,0 +1,400 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState } from 'react';
import {
Card,
Table,
Button,
Space,
Tag,
Typography,
Input,
Popconfirm,
Empty,
Tooltip,
} from '@douyinfe/semi-ui';
import { IconSearch } from '@douyinfe/semi-icons';
import { User } from 'lucide-react';
import {
IllustrationNoResult,
IllustrationNoResultDark,
} from '@douyinfe/semi-illustrations';
import { API, showError, showSuccess } from '../../../helpers';
import OAuth2ClientModal from './modals/OAuth2ClientModal';
import SecretDisplayModal from './modals/SecretDisplayModal';
import ServerInfoModal from './modals/ServerInfoModal';
import JWKSInfoModal from './modals/JWKSInfoModal';
import { useTranslation } from 'react-i18next';
const { Text } = Typography;
export default function OAuth2ClientSettings() {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [clients, setClients] = useState([]);
const [filteredClients, setFilteredClients] = useState([]);
const [searchKeyword, setSearchKeyword] = useState('');
const [showModal, setShowModal] = useState(false);
const [editingClient, setEditingClient] = useState(null);
const [showSecretModal, setShowSecretModal] = useState(false);
const [currentSecret, setCurrentSecret] = useState('');
const [showServerInfoModal, setShowServerInfoModal] = useState(false);
const [showJWKSModal, setShowJWKSModal] = useState(false);
// 加载客户端列表
const loadClients = async () => {
setLoading(true);
try {
const res = await API.get('/api/oauth_clients/');
if (res.data.success) {
setClients(res.data.data || []);
setFilteredClients(res.data.data || []);
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('加载OAuth2客户端失败'));
} finally {
setLoading(false);
}
};
// 搜索过滤
const handleSearch = (value) => {
setSearchKeyword(value);
if (!value) {
setFilteredClients(clients);
} else {
const filtered = clients.filter(
(client) =>
client.name?.toLowerCase().includes(value.toLowerCase()) ||
client.id?.toLowerCase().includes(value.toLowerCase()) ||
client.description?.toLowerCase().includes(value.toLowerCase()),
);
setFilteredClients(filtered);
}
};
// 删除客户端
const handleDelete = async (client) => {
try {
const res = await API.delete(`/api/oauth_clients/${client.id}`);
if (res.data.success) {
showSuccess(t('删除成功'));
loadClients();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('删除失败'));
}
};
// 重新生成密钥
const handleRegenerateSecret = async (client) => {
try {
const res = await API.post(
`/api/oauth_clients/${client.id}/regenerate_secret`,
);
if (res.data.success) {
setCurrentSecret(res.data.client_secret);
setShowSecretModal(true);
loadClients();
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('重新生成密钥失败'));
}
};
// 查看服务器信息
const showServerInfo = () => {
setShowServerInfoModal(true);
};
// 查看JWKS
const showJWKS = () => {
setShowJWKSModal(true);
};
// 表格列定义
const columns = [
{
title: t('客户端名称'),
dataIndex: 'name',
render: (name, record) => (
<div className='flex items-center cursor-help'>
<User size={16} className='mr-1.5 text-gray-500' />
<Tooltip content={record.description || t('暂无描述')} position='top'>
<Text strong>{name}</Text>
</Tooltip>
</div>
),
},
{
title: t('客户端ID'),
dataIndex: 'id',
render: (id) => (
<Text type='tertiary' size='small' code copyable>
{id}
</Text>
),
},
{
title: t('状态'),
dataIndex: 'status',
render: (status) => (
<Tag color={status === 1 ? 'green' : 'red'} shape='circle'>
{status === 1 ? t('启用') : t('禁用')}
</Tag>
),
},
{
title: t('类型'),
dataIndex: 'client_type',
render: (text) => (
<Tag color='white' shape='circle'>
{text === 'confidential' ? t('机密客户端') : t('公开客户端')}
</Tag>
),
},
{
title: t('授权类型'),
dataIndex: 'grant_types',
render: (grantTypes) => {
const types =
typeof grantTypes === 'string'
? grantTypes.split(',')
: grantTypes || [];
const typeMap = {
client_credentials: t('客户端凭证'),
authorization_code: t('授权码'),
refresh_token: t('刷新令牌'),
};
return (
<div className='flex flex-wrap gap-1'>
{types.slice(0, 2).map((type) => (
<Tag color='white' key={type} size='small' shape='circle'>
{typeMap[type] || type}
</Tag>
))}
{types.length > 2 && (
<Tooltip
content={types
.slice(2)
.map((t) => typeMap[t] || t)
.join(', ')}
>
<Tag color='white' size='small' shape='circle'>
+{types.length - 2}
</Tag>
</Tooltip>
)}
</div>
);
},
},
{
title: t('创建时间'),
dataIndex: 'created_time',
render: (time) => new Date(time * 1000).toLocaleString(),
},
{
title: t('操作'),
render: (_, record) => (
<Space size={4} wrap>
<Button
type='primary'
size='small'
onClick={() => {
setEditingClient(record);
setShowModal(true);
}}
>
{t('编辑')}
</Button>
{record.client_type === 'confidential' && (
<Popconfirm
title={t('确认重新生成客户端密钥?')}
content={t('操作不可撤销,旧密钥将立即失效。')}
onConfirm={() => handleRegenerateSecret(record)}
okText={t('确认')}
cancelText={t('取消')}
position='bottomLeft'
>
<Button type='secondary' size='small'>
{t('重新生成密钥')}
</Button>
</Popconfirm>
)}
<Popconfirm
title={t('请再次确认删除该客户端')}
content={t('删除后无法恢复,相关 API 调用将立即失效。')}
onConfirm={() => handleDelete(record)}
okText={t('确定删除')}
cancelText={t('取消')}
position='bottomLeft'
>
<Button type='danger' size='small'>
{t('删除')}
</Button>
</Popconfirm>
</Space>
),
fixed: 'right',
},
];
useEffect(() => {
loadClients();
}, []);
return (
<Card
className='!rounded-2xl shadow-sm border-0'
style={{ marginTop: 10 }}
title={
<div
className='flex flex-col sm:flex-row sm:items-center sm:justify-between w-full gap-3 sm:gap-0'
style={{ paddingRight: '8px' }}
>
<div className='flex items-center'>
<User size={18} className='mr-2' />
<Text strong>{t('OAuth2 客户端管理')}</Text>
<Tag color='white' shape='circle' size='small' className='ml-2'>
{filteredClients.length} {t('个客户端')}
</Tag>
</div>
<div className='flex items-center gap-2 sm:flex-shrink-0 flex-wrap'>
<Input
prefix={<IconSearch />}
placeholder={t('搜索客户端名称、ID或描述')}
value={searchKeyword}
onChange={handleSearch}
showClear
size='small'
style={{ width: 300 }}
/>
<Button type='tertiary' onClick={loadClients} size='small'>
{t('刷新')}
</Button>
<Button type='secondary' onClick={showServerInfo} size='small'>
{t('服务器信息')}
</Button>
<Button type='secondary' onClick={showJWKS} size='small'>
{t('查看JWKS')}
</Button>
<Button
type='primary'
onClick={() => {
setEditingClient(null);
setShowModal(true);
}}
size='small'
>
{t('创建客户端')}
</Button>
</div>
</div>
}
>
<div className='mb-4'>
<Text type='tertiary'>
{t(
'管理OAuth2客户端应用程序每个客户端代表一个可以访问API的应用程序。机密客户端用于服务器端应用公开客户端用于移动应用或单页应用。',
)}
</Text>
</div>
{/* 客户端表格 */}
<Table
columns={columns}
dataSource={filteredClients}
rowKey='id'
loading={loading}
scroll={{ x: 'max-content' }}
pagination={{
showSizeChanger: true,
showQuickJumper: true,
showTotal: true,
pageSize: 10,
}}
empty={
<Empty
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
darkModeImage={
<IllustrationNoResultDark style={{ width: 150, height: 150 }} />
}
title={t('暂无OAuth2客户端')}
description={t(
'还没有创建任何客户端,点击下方按钮创建第一个客户端',
)}
style={{ padding: 30 }}
>
<Button
type='primary'
onClick={() => {
setEditingClient(null);
setShowModal(true);
}}
>
{t('创建第一个客户端')}
</Button>
</Empty>
}
/>
{/* OAuth2 客户端模态框 */}
<OAuth2ClientModal
visible={showModal}
client={editingClient}
onCancel={() => {
setShowModal(false);
setEditingClient(null);
}}
onSuccess={() => {
setShowModal(false);
setEditingClient(null);
loadClients();
}}
/>
{/* 密钥显示模态框 */}
<SecretDisplayModal
visible={showSecretModal}
onClose={() => setShowSecretModal(false)}
secret={currentSecret}
/>
{/* 服务器信息模态框 */}
<ServerInfoModal
visible={showServerInfoModal}
onClose={() => setShowServerInfoModal(false)}
/>
{/* JWKS信息模态框 */}
<JWKSInfoModal
visible={showJWKSModal}
onClose={() => setShowJWKSModal(false)}
/>
</Card>
);
}

View File

@@ -0,0 +1,473 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState, useRef } from 'react';
import {
Banner,
Button,
Col,
Form,
Row,
Card,
Typography,
Badge,
Divider,
} from '@douyinfe/semi-ui';
import { Server } from 'lucide-react';
import JWKSManagerModal from './modals/JWKSManagerModal';
import {
compareObjects,
API,
showError,
showSuccess,
showWarning,
} from '../../../helpers';
import { useTranslation } from 'react-i18next';
const { Text } = Typography;
export default function OAuth2ServerSettings(props) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [inputs, setInputs] = useState({
'oauth2.enabled': false,
'oauth2.issuer': '',
'oauth2.access_token_ttl': 10,
'oauth2.refresh_token_ttl': 720,
'oauth2.jwt_signing_algorithm': 'RS256',
'oauth2.jwt_key_id': 'oauth2-key-1',
'oauth2.allowed_grant_types': [
'client_credentials',
'authorization_code',
'refresh_token',
],
'oauth2.require_pkce': true,
'oauth2.max_jwks_keys': 3,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
const [keysReady, setKeysReady] = useState(true);
const [keysLoading, setKeysLoading] = useState(false);
const [serverInfo, setServerInfo] = useState(null);
const enabledRef = useRef(inputs['oauth2.enabled']);
// 模态框状态
const [jwksVisible, setJwksVisible] = useState(false);
function handleFieldChange(fieldName) {
return (value) => {
setInputs((inputs) => ({ ...inputs, [fieldName]: value }));
};
}
function onSubmit() {
const updateArray = compareObjects(inputs, inputsRow);
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
const requestQueue = updateArray.map((item) => {
let value = '';
if (typeof inputs[item.key] === 'boolean') {
value = String(inputs[item.key]);
} else if (Array.isArray(inputs[item.key])) {
value = JSON.stringify(inputs[item.key]);
} else {
value = inputs[item.key];
}
return API.put('/api/option/', {
key: item.key,
value,
});
});
setLoading(true);
Promise.all(requestQueue)
.then((res) => {
if (requestQueue.length === 1) {
if (res.includes(undefined)) return;
} else if (requestQueue.length > 1) {
if (res.includes(undefined))
return showError(t('部分保存失败,请重试'));
}
showSuccess(t('保存成功'));
if (props && props.refresh) {
props.refresh();
}
})
.catch(() => {
showError(t('保存失败,请重试'));
})
.finally(() => {
setLoading(false);
});
}
// 测试OAuth2连接默认静默仅用户点击时弹提示
const testOAuth2 = async (silent = true) => {
// 未启用时不触发测试,避免 404
if (!enabledRef.current) return;
try {
const res = await API.get('/api/oauth/server-info', {
skipErrorHandler: true,
});
if (!enabledRef.current) return;
if (
res.status === 200 &&
(res.data.issuer || res.data.authorization_endpoint)
) {
if (!silent) showSuccess('OAuth2服务器运行正常');
setServerInfo(res.data);
} else {
if (!enabledRef.current) return;
if (!silent) showError('OAuth2服务器测试失败');
}
} catch (error) {
if (!enabledRef.current) return;
if (!silent) showError('OAuth2服务器连接测试失败');
}
};
useEffect(() => {
if (props && props.options) {
const currentInputs = {};
for (let key in props.options) {
if (Object.keys(inputs).includes(key)) {
if (key === 'oauth2.allowed_grant_types') {
try {
currentInputs[key] = JSON.parse(
props.options[key] ||
'["client_credentials","authorization_code","refresh_token"]',
);
} catch {
currentInputs[key] = [
'client_credentials',
'authorization_code',
'refresh_token',
];
}
} else if (typeof inputs[key] === 'boolean') {
currentInputs[key] = props.options[key] === 'true';
} else if (typeof inputs[key] === 'number') {
currentInputs[key] = parseInt(props.options[key]) || inputs[key];
} else {
currentInputs[key] = props.options[key];
}
}
}
setInputs({ ...inputs, ...currentInputs });
setInputsRow(structuredClone({ ...inputs, ...currentInputs }));
if (refForm.current) {
refForm.current.setValues({ ...inputs, ...currentInputs });
}
}
}, [props]);
useEffect(() => {
enabledRef.current = inputs['oauth2.enabled'];
}, [inputs['oauth2.enabled']]);
useEffect(() => {
const loadKeys = async () => {
try {
setKeysLoading(true);
const res = await API.get('/api/oauth/keys', {
skipErrorHandler: true,
});
const list = res?.data?.data || [];
setKeysReady(list.length > 0);
} catch {
setKeysReady(false);
} finally {
setKeysLoading(false);
}
};
if (inputs['oauth2.enabled']) {
loadKeys();
testOAuth2(true);
} else {
// 禁用时清理状态,避免残留状态与不必要的请求
setKeysReady(true);
setServerInfo(null);
setKeysLoading(false);
}
}, [inputs['oauth2.enabled']]);
const isEnabled = inputs['oauth2.enabled'];
return (
<div>
{/* OAuth2 服务端管理 */}
<Card
className='!rounded-2xl shadow-sm border-0'
style={{ marginTop: 10 }}
title={
<div
className='flex flex-col sm:flex-row sm:items-center sm:justify-between w-full gap-3 sm:gap-0'
style={{ paddingRight: '8px' }}
>
<div className='flex items-center'>
<Server size={18} className='mr-2' />
<Text strong>{t('OAuth2 服务端管理')}</Text>
{isEnabled ? (
serverInfo ? (
<Badge
count={t('运行正常')}
type='success'
style={{ marginLeft: 8 }}
/>
) : (
<Badge
count={t('配置中')}
type='warning'
style={{ marginLeft: 8 }}
/>
)
) : (
<Badge
count={t('未启用')}
type='tertiary'
style={{ marginLeft: 8 }}
/>
)}
</div>
<div className='flex items-center gap-2 sm:flex-shrink-0'>
{isEnabled && (
<Button
type='secondary'
onClick={() => setJwksVisible(true)}
size='small'
>
{t('密钥管理')}
</Button>
)}
<Button
type='primary'
onClick={onSubmit}
loading={loading}
size='small'
>
{t('保存配置')}
</Button>
</div>
</div>
}
>
<Form
initValues={inputs}
getFormApi={(formAPI) => (refForm.current = formAPI)}
>
{!keysReady && isEnabled && (
<Banner
type='warning'
className='!rounded-lg'
closeIcon={null}
description={t(
'尚未准备签名密钥,建议立即初始化或轮换以发布 JWKS。签名密钥用于 JWT 令牌的安全签发。',
)}
/>
)}
<Row gutter={[16, 24]}>
<Col xs={24} lg={12}>
<Form.Switch
field='oauth2.enabled'
label={t('启用 OAuth2 & SSO')}
value={inputs['oauth2.enabled']}
onChange={handleFieldChange('oauth2.enabled')}
extraText={t('开启后将允许以 OAuth2/OIDC 标准进行授权与登录')}
/>
</Col>
<Col xs={24} lg={12}>
<Form.Input
field='oauth2.issuer'
label={t('发行人 (Issuer)')}
placeholder={window.location.origin}
value={inputs['oauth2.issuer']}
onChange={handleFieldChange('oauth2.issuer')}
extraText={t('为空则按请求自动推断(含 X-Forwarded-Proto')}
/>
</Col>
</Row>
{/* 令牌配置 */}
<Divider margin='24px'>{t('令牌配置')}</Divider>
<Row gutter={[16, 24]}>
<Col xs={24} sm={12} lg={8}>
<Form.InputNumber
field='oauth2.access_token_ttl'
label={t('访问令牌有效期')}
suffix={t('分钟')}
min={1}
max={1440}
value={inputs['oauth2.access_token_ttl']}
onChange={handleFieldChange('oauth2.access_token_ttl')}
extraText={t('访问令牌的有效时间建议较短10-60分钟')}
style={{
width: '100%',
opacity: isEnabled ? 1 : 0.5,
}}
disabled={!isEnabled}
/>
</Col>
<Col xs={24} sm={12} lg={8}>
<Form.InputNumber
field='oauth2.refresh_token_ttl'
label={t('刷新令牌有效期')}
suffix={t('小时')}
min={1}
max={8760}
value={inputs['oauth2.refresh_token_ttl']}
onChange={handleFieldChange('oauth2.refresh_token_ttl')}
extraText={t('刷新令牌的有效时间建议较长12-720小时')}
style={{
width: '100%',
opacity: isEnabled ? 1 : 0.5,
}}
disabled={!isEnabled}
/>
</Col>
<Col xs={24} sm={12} lg={8}>
<Form.InputNumber
field='oauth2.max_jwks_keys'
label={t('JWKS历史保留上限')}
min={1}
max={10}
value={inputs['oauth2.max_jwks_keys']}
onChange={handleFieldChange('oauth2.max_jwks_keys')}
extraText={t('轮换后最多保留的历史签名密钥数量')}
style={{
width: '100%',
opacity: isEnabled ? 1 : 0.5,
}}
disabled={!isEnabled}
/>
</Col>
</Row>
<Row gutter={[16, 24]} style={{ marginTop: 16 }}>
<Col xs={24} lg={12}>
<Form.Select
field='oauth2.jwt_signing_algorithm'
label={t('JWT签名算法')}
value={inputs['oauth2.jwt_signing_algorithm']}
onChange={handleFieldChange('oauth2.jwt_signing_algorithm')}
extraText={t('JWT令牌的签名算法推荐使用RS256')}
style={{
width: '100%',
opacity: isEnabled ? 1 : 0.5,
}}
disabled={!isEnabled}
>
<Form.Select.Option value='RS256'>
RS256 (RSA with SHA-256)
</Form.Select.Option>
<Form.Select.Option value='HS256'>
HS256 (HMAC with SHA-256)
</Form.Select.Option>
</Form.Select>
</Col>
<Col xs={24} lg={12}>
<Form.Input
field='oauth2.jwt_key_id'
label={t('JWT密钥ID')}
placeholder='oauth2-key-1'
value={inputs['oauth2.jwt_key_id']}
onChange={handleFieldChange('oauth2.jwt_key_id')}
extraText={t('用于标识JWT签名密钥支持密钥轮换')}
style={{
width: '100%',
opacity: isEnabled ? 1 : 0.5,
}}
disabled={!isEnabled}
/>
</Col>
</Row>
{/* 授权配置 */}
<Divider margin='24px'>{t('授权配置')}</Divider>
<Row gutter={[16, 24]}>
<Col xs={24} lg={12}>
<Form.Select
field='oauth2.allowed_grant_types'
label={t('允许的授权类型')}
multiple
value={inputs['oauth2.allowed_grant_types']}
onChange={handleFieldChange('oauth2.allowed_grant_types')}
extraText={t('选择允许的OAuth2授权流程')}
style={{
width: '100%',
opacity: isEnabled ? 1 : 0.5,
}}
disabled={!isEnabled}
>
<Form.Select.Option value='client_credentials'>
{t('Client Credentials客户端凭证')}
</Form.Select.Option>
<Form.Select.Option value='authorization_code'>
{t('Authorization Code授权码')}
</Form.Select.Option>
<Form.Select.Option value='refresh_token'>
{t('Refresh Token刷新令牌')}
</Form.Select.Option>
</Form.Select>
</Col>
<Col xs={24} lg={12}>
<Form.Switch
field='oauth2.require_pkce'
label={t('强制PKCE验证')}
value={inputs['oauth2.require_pkce']}
onChange={handleFieldChange('oauth2.require_pkce')}
extraText={t('为授权码流程强制启用PKCE提高安全性')}
disabled={!isEnabled}
/>
</Col>
</Row>
<div style={{ marginTop: 16 }}>
<Text type='tertiary' size='small'>
<div className='space-y-1'>
<div> {t('OAuth2 服务器提供标准的 API 认证与授权')}</div>
<div>
{' '}
{t(
'支持 Client Credentials、Authorization Code + PKCE 等标准流程',
)}
</div>
<div>
{' '}
{t(
'配置保存后多数项即时生效;签名密钥轮换与 JWKS 发布为即时操作',
)}
</div>
<div>
{t('生产环境务必启用 HTTPS并妥善管理 JWT 签名密钥')}
</div>
</div>
</Text>
</div>
</Form>
</Card>
{/* 模态框 */}
<JWKSManagerModal
visible={jwksVisible}
onClose={() => setJwksVisible(false)}
/>
</div>
);
}

View File

@@ -0,0 +1,78 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Modal, Banner, Typography } from '@douyinfe/semi-ui';
import { useTranslation } from 'react-i18next';
const { Text } = Typography;
const ClientInfoModal = ({ visible, onClose, clientId, clientSecret }) => {
const { t } = useTranslation();
return (
<Modal
title={t('客户端创建成功')}
visible={visible}
onCancel={onClose}
onOk={onClose}
cancelText=''
okText={t('我已复制保存')}
width={650}
bodyStyle={{ padding: '20px 24px' }}
>
<Banner
type='success'
closeIcon={null}
description={t(
'客户端信息如下,请立即复制保存。关闭此窗口后将无法再次查看密钥。',
)}
className='mb-5 !rounded-lg'
/>
<div className='space-y-4'>
<div className='flex justify-center items-center'>
<div className='text-center'>
<Text strong className='block mb-2'>
{t('客户端ID')}
</Text>
<Text code copyable>
{clientId}
</Text>
</div>
</div>
{clientSecret && (
<div className='flex justify-center items-center'>
<div className='text-center'>
<Text strong className='block mb-2'>
{t('客户端密钥(仅此一次显示)')}
</Text>
<Text code copyable>
{clientSecret}
</Text>
</div>
</div>
)}
</div>
</Modal>
);
};
export default ClientInfoModal;

View File

@@ -0,0 +1,70 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect } from 'react';
import { Modal } from '@douyinfe/semi-ui';
import { API, showError } from '../../../../helpers';
import { useTranslation } from 'react-i18next';
import CodeViewer from '../../../common/ui/CodeViewer';
const JWKSInfoModal = ({ visible, onClose }) => {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [jwksInfo, setJwksInfo] = useState(null);
const loadJWKSInfo = async () => {
setLoading(true);
try {
const res = await API.get('/api/oauth/jwks');
setJwksInfo(res.data);
} catch (error) {
showError(t('获取JWKS失败'));
} finally {
setLoading(false);
}
};
useEffect(() => {
if (visible) {
loadJWKSInfo();
}
}, [visible]);
return (
<Modal
title={t('JWKS 信息')}
visible={visible}
onCancel={onClose}
onOk={onClose}
cancelText=''
okText={t('关闭')}
width={650}
bodyStyle={{ padding: '20px 24px' }}
confirmLoading={loading}
>
<CodeViewer
content={jwksInfo ? JSON.stringify(jwksInfo, null, 2) : t('加载中...')}
title={t('JWKS 密钥集')}
language='json'
/>
</Modal>
);
};
export default JWKSInfoModal;

View File

@@ -0,0 +1,399 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState } from 'react';
import {
Table,
Button,
Space,
Tag,
Typography,
Popconfirm,
Toast,
Form,
Card,
Tabs,
TabPane,
} from '@douyinfe/semi-ui';
import { API, showError, showSuccess } from '../../../../helpers';
import { useTranslation } from 'react-i18next';
import ResponsiveModal from '../../../common/ui/ResponsiveModal';
const { Text } = Typography;
// 操作模式枚举
const OPERATION_MODES = {
VIEW: 'view',
IMPORT: 'import',
GENERATE: 'generate',
};
export default function JWKSManagerModal({ visible, onClose }) {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [keys, setKeys] = useState([]);
const [activeTab, setActiveTab] = useState(OPERATION_MODES.VIEW);
const load = async () => {
setLoading(true);
try {
const res = await API.get('/api/oauth/keys');
if (res?.data?.success) setKeys(res.data.data || []);
else showError(res?.data?.message || t('获取密钥列表失败'));
} catch {
showError(t('获取密钥列表失败'));
} finally {
setLoading(false);
}
};
const rotate = async () => {
setLoading(true);
try {
const res = await API.post('/api/oauth/keys/rotate', {});
if (res?.data?.success) {
showSuccess(t('签名密钥已轮换:{{kid}}', { kid: res.data.kid }));
await load();
} else showError(res?.data?.message || t('密钥轮换失败'));
} catch {
showError(t('密钥轮换失败'));
} finally {
setLoading(false);
}
};
const del = async (kid) => {
setLoading(true);
try {
const res = await API.delete(`/api/oauth/keys/${kid}`);
if (res?.data?.success) {
Toast.success(t('已删除:{{kid}}', { kid }));
await load();
} else showError(res?.data?.message || t('删除失败'));
} catch {
showError(t('删除失败'));
} finally {
setLoading(false);
}
};
// Import PEM state
const [pem, setPem] = useState('');
const [customKid, setCustomKid] = useState('');
// Generate PEM file state
const [genPath, setGenPath] = useState('/etc/new-api/oauth2-private.pem');
const [genKid, setGenKid] = useState('');
// 重置表单数据
const resetForms = () => {
setPem('');
setCustomKid('');
setGenKid('');
};
useEffect(() => {
if (visible) {
load();
// 重置到主视图
setActiveTab(OPERATION_MODES.VIEW);
} else {
// 模态框关闭时重置表单数据
resetForms();
}
}, [visible]);
useEffect(() => {
if (!visible) return;
(async () => {
try {
const res = await API.get('/api/oauth/server-info');
const p = res?.data?.default_private_key_path;
if (p) setGenPath(p);
} catch {}
})();
}, [visible]);
// 导入 PEM 私钥
const importPem = async () => {
if (!pem.trim()) return Toast.warning(t('请粘贴 PEM 私钥'));
setLoading(true);
try {
const res = await API.post('/api/oauth/keys/import_pem', {
pem,
kid: customKid.trim(),
});
if (res?.data?.success) {
Toast.success(
t('已导入私钥并切换到 kid={{kid}}', { kid: res.data.kid }),
);
resetForms();
setActiveTab(OPERATION_MODES.VIEW);
await load();
} else {
Toast.error(res?.data?.message || t('导入失败'));
}
} catch {
Toast.error(t('导入失败'));
} finally {
setLoading(false);
}
};
// 生成 PEM 文件
const generatePemFile = async () => {
if (!genPath.trim()) return Toast.warning(t('请填写保存路径'));
setLoading(true);
try {
const res = await API.post('/api/oauth/keys/generate_file', {
path: genPath.trim(),
kid: genKid.trim(),
});
if (res?.data?.success) {
Toast.success(t('已生成并生效:{{path}}', { path: res.data.path }));
resetForms();
setActiveTab(OPERATION_MODES.VIEW);
await load();
} else {
Toast.error(res?.data?.message || t('生成失败'));
}
} catch {
Toast.error(t('生成失败'));
} finally {
setLoading(false);
}
};
const columns = [
{
title: 'KID',
dataIndex: 'kid',
render: (kid) => (
<Text code copyable>
{kid}
</Text>
),
},
{
title: t('创建时间'),
dataIndex: 'created_at',
render: (ts) => (ts ? new Date(ts * 1000).toLocaleString() : '-'),
},
{
title: t('状态'),
dataIndex: 'current',
render: (cur) =>
cur ? (
<Tag color='green' shape='circle'>
{t('当前')}
</Tag>
) : (
<Tag shape='circle'>{t('历史')}</Tag>
),
},
{
title: t('操作'),
render: (_, r) => (
<Space>
{!r.current && (
<Popconfirm
title={t('确定删除密钥 {{kid}} ', { kid: r.kid })}
content={t(
'删除后使用该 kid 签发的旧令牌仍可被验证(外部 JWKS 缓存可能仍保留)',
)}
okText={t('删除')}
onConfirm={() => del(r.kid)}
>
<Button size='small' type='danger'>
{t('删除')}
</Button>
</Popconfirm>
)}
</Space>
),
},
];
// 头部操作按钮 - 根据当前标签页动态生成
const getHeaderActions = () => {
if (activeTab === OPERATION_MODES.VIEW) {
const hasKeys = Array.isArray(keys) && keys.length > 0;
return [
<Button key='refresh' onClick={load} loading={loading} size='small'>
{t('刷新')}
</Button>,
<Button
key='rotate'
type='primary'
onClick={rotate}
loading={loading}
size='small'
>
{hasKeys ? t('轮换密钥') : t('初始化密钥')}
</Button>,
];
}
if (activeTab === OPERATION_MODES.IMPORT) {
return [
<Button
key='import'
type='primary'
onClick={importPem}
loading={loading}
size='small'
>
{t('导入并生效')}
</Button>,
];
}
if (activeTab === OPERATION_MODES.GENERATE) {
return [
<Button
key='generate'
type='primary'
onClick={generatePemFile}
loading={loading}
size='small'
>
{t('生成并生效')}
</Button>,
];
}
return [];
};
// 渲染密钥列表视图
const renderKeysView = () => (
<Card
className='!rounded-lg'
title={
<Text className='text-blue-700 dark:text-blue-300'>
{t(
'提示:当前密钥用于签发 JWT 令牌。建议定期轮换密钥以提升安全性。只有历史密钥可以删除。',
)}
</Text>
}
>
<Table
dataSource={keys}
columns={columns}
rowKey='kid'
loading={loading}
pagination={false}
empty={<Text type='tertiary'>{t('暂无密钥')}</Text>}
/>
</Card>
);
// 渲染导入 PEM 私钥视图
const renderImportView = () => (
<Card
className='!rounded-lg'
title={
<Text className='text-yellow-700 dark:text-yellow-300'>
{t(
'建议:优先使用内存签名密钥与 JWKS 轮换;仅在有合规要求时导入外部私钥。请确保私钥来源可信。',
)}
</Text>
}
>
<Form labelPosition='left' labelWidth={120}>
<Form.Input
field='kid'
label={t('自定义 KID')}
placeholder={t('可留空自动生成')}
value={customKid}
onChange={setCustomKid}
/>
<Form.TextArea
field='pem'
label={t('PEM 私钥')}
value={pem}
onChange={setPem}
rows={8}
placeholder={
'-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----'
}
/>
</Form>
</Card>
);
// 渲染生成 PEM 文件视图
const renderGenerateView = () => (
<Card
className='!rounded-lg'
title={
<Text className='text-orange-700 dark:text-orange-300'>
{t(
'建议:仅在合规要求下使用文件私钥。请确保目录权限安全(建议 0600并妥善备份。',
)}
</Text>
}
>
<Form labelPosition='left' labelWidth={120}>
<Form.Input
field='path'
label={t('保存路径')}
value={genPath}
onChange={setGenPath}
placeholder='/secure/path/oauth2-private.pem'
/>
<Form.Input
field='genKid'
label={t('自定义 KID')}
value={genKid}
onChange={setGenKid}
placeholder={t('可留空自动生成')}
/>
</Form>
</Card>
);
return (
<ResponsiveModal
visible={visible}
title={t('JWKS 管理')}
headerActions={getHeaderActions()}
onCancel={onClose}
footer={null}
width={{ mobile: '95%', desktop: 800 }}
>
<Tabs
activeKey={activeTab}
onChange={setActiveTab}
type='card'
size='medium'
className='!-mt-2'
>
<TabPane tab={t('密钥列表')} itemKey={OPERATION_MODES.VIEW}>
{renderKeysView()}
</TabPane>
<TabPane tab={t('导入 PEM 私钥')} itemKey={OPERATION_MODES.IMPORT}>
{renderImportView()}
</TabPane>
<TabPane tab={t('生成 PEM 文件')} itemKey={OPERATION_MODES.GENERATE}>
{renderGenerateView()}
</TabPane>
</Tabs>
</ResponsiveModal>
);
}

View File

@@ -0,0 +1,730 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState, useRef } from 'react';
import {
SideSheet,
Form,
Input,
Select,
Space,
Typography,
Button,
Card,
Avatar,
Tag,
Spin,
Radio,
Divider,
} from '@douyinfe/semi-ui';
import {
IconKey,
IconLink,
IconSave,
IconClose,
IconPlus,
IconDelete,
} from '@douyinfe/semi-icons';
import { API, showError, showSuccess } from '../../../../helpers';
import { useTranslation } from 'react-i18next';
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
import ClientInfoModal from './ClientInfoModal';
const { Text, Title } = Typography;
const { Option } = Select;
const AUTH_CODE = 'authorization_code';
const CLIENT_CREDENTIALS = 'client_credentials';
// 子组件重定向URI编辑卡片
function RedirectUriCard({
t,
isAuthCodeSelected,
redirectUris,
onAdd,
onUpdate,
onRemove,
onFillTemplate,
}) {
return (
<Card
header={
<div className='flex justify-between items-center'>
<div className='flex items-center'>
<Avatar size='small' color='purple' className='mr-2 shadow-md'>
<IconLink size={16} />
</Avatar>
<div>
<Text className='text-lg font-medium'>{t('重定向URI配置')}</Text>
<div className='text-xs text-gray-600'>
{t('用于授权码流程的重定向地址')}
</div>
</div>
</div>
<Button
type='tertiary'
onClick={onFillTemplate}
size='small'
disabled={!isAuthCodeSelected}
>
{t('填入示例模板')}
</Button>
</div>
}
headerStyle={{ padding: '12px 16px' }}
bodyStyle={{ padding: '16px' }}
className='!rounded-2xl shadow-sm border-0'
>
<div className='space-y-1'>
{redirectUris.length === 0 && (
<div className='text-center py-4 px-4'>
<Text type='tertiary' className='text-gray-500 text-sm'>
{t('暂无重定向URI点击下方按钮添加')}
</Text>
</div>
)}
{redirectUris.map((uri, index) => (
<div
key={index}
style={{
marginBottom: 8,
display: 'flex',
gap: 8,
alignItems: 'center',
}}
>
<Input
placeholder={t('例如https://your-app.com/callback')}
value={uri}
onChange={(value) => onUpdate(index, value)}
style={{ flex: 1 }}
disabled={!isAuthCodeSelected}
/>
<Button
icon={<IconDelete />}
type='danger'
theme='borderless'
onClick={() => onRemove(index)}
disabled={!isAuthCodeSelected}
/>
</div>
))}
<div className='py-2 flex justify-center gap-2'>
<Button
icon={<IconPlus />}
type='primary'
theme='outline'
onClick={onAdd}
disabled={!isAuthCodeSelected}
>
{t('添加重定向URI')}
</Button>
</div>
</div>
<Divider margin='12px' align='center'>
<Text type='tertiary' size='small'>
{isAuthCodeSelected
? t(
'用户授权后将重定向到这些URI。必须使用HTTPS本地开发可使用HTTP仅限localhost/127.0.0.1',
)
: t('仅在选择“授权码”授权类型时需要配置重定向URI')}
</Text>
</Divider>
</Card>
);
}
const OAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
const { t } = useTranslation();
const isMobile = useIsMobile();
const formApiRef = useRef(null);
const [loading, setLoading] = useState(false);
const [redirectUris, setRedirectUris] = useState([]);
const [clientType, setClientType] = useState('confidential');
const [grantTypes, setGrantTypes] = useState([]);
const [allowedGrantTypes, setAllowedGrantTypes] = useState([
CLIENT_CREDENTIALS,
AUTH_CODE,
'refresh_token',
]);
// ClientInfoModal 状态
const [showClientInfo, setShowClientInfo] = useState(false);
const [clientInfo, setClientInfo] = useState({
clientId: '',
clientSecret: '',
});
const isEdit = client?.id !== undefined;
const [mode, setMode] = useState('create'); // 'create' | 'edit'
useEffect(() => {
if (visible) {
setMode(isEdit ? 'edit' : 'create');
}
}, [visible, isEdit]);
const getInitValues = () => ({
name: '',
description: '',
client_type: 'confidential',
grant_types: [],
scopes: [],
require_pkce: true,
status: 1,
});
// 加载后端允许的授权类型
useEffect(() => {
let mounted = true;
(async () => {
try {
const res = await API.get('/api/option/');
const { success, data } = res.data || {};
if (!success || !Array.isArray(data)) return;
const found = data.find((i) => i.key === 'oauth2.allowed_grant_types');
if (!found) return;
let parsed = [];
try {
parsed = JSON.parse(found.value || '[]');
} catch (_) {}
if (mounted && Array.isArray(parsed) && parsed.length) {
setAllowedGrantTypes(parsed);
}
} catch (_) {
// 忽略错误使用默认allowedGrantTypes
}
})();
return () => {
mounted = false;
};
}, []);
useEffect(() => {
setGrantTypes((prev) => {
const normalizedPrev = Array.isArray(prev) ? prev : [];
// 移除不被允许或与客户端类型冲突的类型
let next = normalizedPrev.filter((g) => allowedGrantTypes.includes(g));
if (clientType === 'public') {
next = next.filter((g) => g !== CLIENT_CREDENTIALS);
}
return next.length ? next : [];
});
}, [clientType, allowedGrantTypes]);
// 初始化表单数据(编辑模式)
useEffect(() => {
if (client && visible && isEdit) {
setLoading(true);
// 解析授权类型
let parsedGrantTypes = [];
if (typeof client.grant_types === 'string') {
parsedGrantTypes = client.grant_types.split(',');
} else if (Array.isArray(client.grant_types)) {
parsedGrantTypes = client.grant_types;
}
// 解析Scope
let parsedScopes = [];
if (typeof client.scopes === 'string') {
parsedScopes = client.scopes.split(',');
} else if (Array.isArray(client.scopes)) {
parsedScopes = client.scopes;
}
if (!parsedScopes || parsedScopes.length === 0) {
parsedScopes = ['openid', 'profile', 'email', 'api:read'];
}
// 解析重定向URI
let parsedRedirectUris = [];
if (client.redirect_uris) {
try {
const parsed =
typeof client.redirect_uris === 'string'
? JSON.parse(client.redirect_uris)
: client.redirect_uris;
if (Array.isArray(parsed) && parsed.length > 0) {
parsedRedirectUris = parsed;
}
} catch (e) {}
}
// 过滤不被允许或不兼容的授权类型
const filteredGrantTypes = (parsedGrantTypes || []).filter((g) =>
allowedGrantTypes.includes(g),
);
const finalGrantTypes =
client.client_type === 'public'
? filteredGrantTypes.filter((g) => g !== CLIENT_CREDENTIALS)
: filteredGrantTypes;
setClientType(client.client_type);
setGrantTypes(finalGrantTypes);
// 不自动新增空白URI保持与创建模式一致的手动添加体验
setRedirectUris(parsedRedirectUris);
// 设置表单值
const formValues = {
id: client.id,
name: client.name,
description: client.description,
client_type: client.client_type,
grant_types: finalGrantTypes,
scopes: parsedScopes,
require_pkce: !!client.require_pkce,
status: client.status,
};
setTimeout(() => {
if (formApiRef.current) {
formApiRef.current.setValues(formValues);
}
setLoading(false);
}, 100);
} else if (visible && !isEdit) {
// 创建模式,重置状态
setClientType('confidential');
setGrantTypes([]);
setRedirectUris([]);
if (formApiRef.current) {
formApiRef.current.setValues(getInitValues());
}
}
}, [client, visible, isEdit, allowedGrantTypes]);
const isAuthCodeSelected = grantTypes.includes(AUTH_CODE);
const isGrantTypeDisabled = (value) => {
if (!allowedGrantTypes.includes(value)) return true;
if (clientType === 'public' && value === CLIENT_CREDENTIALS) return true;
return false;
};
// URL校验允许 httpshttp 仅限本地开发域名
const isValidRedirectUri = (uri) => {
if (!uri || !uri.trim()) return false;
try {
const u = new URL(uri.trim());
if (u.protocol === 'https:') return true;
if (u.protocol === 'http:') {
const host = u.hostname;
return (
host === 'localhost' ||
host === '127.0.0.1' ||
host.endsWith('.local')
);
}
return false;
} catch (_) {
return false;
}
};
// 处理提交
const handleSubmit = async (values) => {
setLoading(true);
try {
// 过滤空的重定向URI
const validRedirectUris = redirectUris
.map((u) => (u || '').trim())
.filter((u) => u.length > 0);
// 业务校验
if (!grantTypes.length) {
showError(t('请至少选择一种授权类型'));
setLoading(false);
return;
}
// 校验是否包含不被允许的授权类型
const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g));
if (invalids.length) {
showError(
t('不被允许的授权类型: {{types}}', { types: invalids.join(', ') }),
);
setLoading(false);
return;
}
if (clientType === 'public' && grantTypes.includes(CLIENT_CREDENTIALS)) {
showError(t('公开客户端不允许使用client_credentials授权类型'));
setLoading(false);
return;
}
if (grantTypes.includes(AUTH_CODE)) {
if (!validRedirectUris.length) {
showError(t('选择授权码授权类型时必须填写至少一个重定向URI'));
setLoading(false);
return;
}
const allValid = validRedirectUris.every(isValidRedirectUri);
if (!allValid) {
showError(t('重定向URI格式不合法仅支持https或本地开发使用http'));
setLoading(false);
return;
}
}
// 避免把 Radio 组件对象形式的 client_type 直接传给后端
const { client_type: _formClientType, ...restValues } = values || {};
const payload = {
...restValues,
client_type: clientType,
grant_types: grantTypes,
redirect_uris: validRedirectUris,
};
let res;
if (isEdit) {
res = await API.put('/api/oauth_clients/', payload);
} else {
res = await API.post('/api/oauth_clients/', payload);
}
const { success, message, client_id, client_secret } = res.data;
if (success) {
if (isEdit) {
showSuccess(t('OAuth2客户端更新成功'));
resetForm();
onSuccess();
} else {
showSuccess(t('OAuth2客户端创建成功'));
// 显示客户端信息
setClientInfo({
clientId: client_id,
clientSecret: client_secret,
});
setShowClientInfo(true);
}
} else {
showError(message);
}
} catch (error) {
showError(isEdit ? t('更新OAuth2客户端失败') : t('创建OAuth2客户端失败'));
} finally {
setLoading(false);
}
};
// 重置表单
const resetForm = () => {
if (formApiRef.current) {
formApiRef.current.reset();
}
setClientType('confidential');
setGrantTypes([]);
setRedirectUris([]);
};
// 处理ClientInfoModal关闭
const handleClientInfoClose = () => {
setShowClientInfo(false);
setClientInfo({ clientId: '', clientSecret: '' });
resetForm();
onSuccess();
};
// 处理取消
const handleCancel = () => {
resetForm();
onCancel();
};
// 添加重定向URI
const addRedirectUri = () => {
setRedirectUris([...redirectUris, '']);
};
// 删除重定向URI
const removeRedirectUri = (index) => {
setRedirectUris(redirectUris.filter((_, i) => i !== index));
};
// 更新重定向URI
const updateRedirectUri = (index, value) => {
const newUris = [...redirectUris];
newUris[index] = value;
setRedirectUris(newUris);
};
// 填入示例重定向URI模板
const fillRedirectUriTemplate = () => {
const template = [
'https://your-app.com/auth/callback',
'https://localhost:3000/callback',
];
setRedirectUris(template);
};
// 授权类型变化处理(清理非法项,只设置一次)
const handleGrantTypesChange = (values) => {
const allowed = Array.isArray(values)
? values.filter((v) => allowedGrantTypes.includes(v))
: [];
const sanitized =
clientType === 'public'
? allowed.filter((v) => v !== CLIENT_CREDENTIALS)
: allowed;
setGrantTypes(sanitized);
if (formApiRef.current) {
formApiRef.current.setValue('grant_types', sanitized);
}
};
// 客户端类型变化处理(兼容 RadioGroup 事件对象与直接值)
const handleClientTypeChange = (next) => {
const value = next && next.target ? next.target.value : next;
setClientType(value);
// 公开客户端自动移除 client_credentials并同步表单字段
const current = Array.isArray(grantTypes) ? grantTypes : [];
const sanitized =
value === 'public'
? current.filter((g) => g !== CLIENT_CREDENTIALS)
: current;
if (sanitized !== current) {
setGrantTypes(sanitized);
if (formApiRef.current) {
formApiRef.current.setValue('grant_types', sanitized);
}
}
};
return (
<SideSheet
placement={mode === 'edit' ? 'right' : 'left'}
title={
<Space>
{mode === 'edit' ? (
<Tag color='blue' shape='circle'>
{t('编辑')}
</Tag>
) : (
<Tag color='green' shape='circle'>
{t('创建')}
</Tag>
)}
<Title heading={4} className='m-0'>
{mode === 'edit' ? t('编辑OAuth2客户端') : t('创建OAuth2客户端')}
</Title>
</Space>
}
bodyStyle={{ padding: '0' }}
visible={visible}
width={isMobile ? '100%' : 700}
footer={
<div className='flex justify-end bg-white'>
<Space>
<Button
theme='solid'
className='!rounded-lg'
onClick={() => formApiRef.current?.submitForm()}
icon={<IconSave />}
loading={loading}
>
{isEdit ? t('保存') : t('创建')}
</Button>
<Button
theme='light'
className='!rounded-lg'
type='primary'
onClick={handleCancel}
icon={<IconClose />}
>
{t('取消')}
</Button>
</Space>
</div>
}
closeIcon={null}
onCancel={handleCancel}
>
<Spin spinning={loading}>
<Form
key={isEdit ? `edit-${client?.id}` : 'create'}
initValues={getInitValues()}
getFormApi={(api) => (formApiRef.current = api)}
onSubmit={handleSubmit}
>
{() => (
<div className='p-2'>
{/* 表单内容 */}
{/* 基本信息 */}
<Card className='!rounded-2xl shadow-sm border-0'>
<div className='flex items-center mb-4'>
<Avatar size='small' color='blue' className='mr-2 shadow-md'>
<IconKey size={16} />
</Avatar>
<div>
<Text className='text-lg font-medium'>{t('基本信息')}</Text>
<div className='text-xs text-gray-600'>
{t('设置客户端的基本信息')}
</div>
</div>
</div>
{isEdit && (
<>
<Form.Select
field='status'
label={t('状态')}
rules={[{ required: true, message: t('请选择状态') }]}
required
>
<Option value={1}>{t('启用')}</Option>
<Option value={2}>{t('禁用')}</Option>
</Form.Select>
<Form.Input field='id' label={t('客户端ID')} disabled />
</>
)}
<Form.Input
field='name'
label={t('客户端名称')}
placeholder={t('输入客户端名称')}
rules={[{ required: true, message: t('请输入客户端名称') }]}
required
showClear
/>
<Form.TextArea
field='description'
label={t('描述')}
placeholder={t('输入客户端描述')}
rows={3}
showClear
/>
<Form.RadioGroup
label={t('客户端类型')}
field='client_type'
value={clientType}
onChange={handleClientTypeChange}
type='card'
aria-label={t('选择客户端类型')}
disabled={isEdit}
rules={[{ required: true, message: t('请选择客户端类型') }]}
required
>
<Radio
value='confidential'
extra={t('服务器端应用,安全地存储客户端密钥')}
style={{ width: isMobile ? '100%' : 'auto' }}
>
{t('机密客户端Confidential')}
</Radio>
<Radio
value='public'
extra={t('移动应用或单页应用,无法安全存储密钥')}
style={{ width: isMobile ? '100%' : 'auto' }}
>
{t('公开客户端Public')}
</Radio>
</Form.RadioGroup>
<Form.Select
field='grant_types'
label={t('允许的授权类型')}
multiple
value={grantTypes}
onChange={handleGrantTypesChange}
rules={[
{ required: true, message: t('请选择至少一种授权类型') },
]}
required
placeholder={t('请选择授权类型(可多选)')}
>
{clientType !== 'public' && (
<Option
value={CLIENT_CREDENTIALS}
disabled={isGrantTypeDisabled(CLIENT_CREDENTIALS)}
>
{t('Client Credentials客户端凭证')}
</Option>
)}
<Option
value={AUTH_CODE}
disabled={isGrantTypeDisabled(AUTH_CODE)}
>
{t('Authorization Code授权码')}
</Option>
<Option
value='refresh_token'
disabled={isGrantTypeDisabled('refresh_token')}
>
{t('Refresh Token刷新令牌')}
</Option>
</Form.Select>
<Form.Select
field='scopes'
label={t('允许的权限范围Scope')}
multiple
rules={[
{ required: true, message: t('请选择至少一个权限范围') },
]}
required
placeholder={t('请选择权限范围(可多选)')}
>
<Option value='openid'>{t('openidOIDC 基础身份)')}</Option>
<Option value='profile'>
{t('profile用户名/昵称等)')}
</Option>
<Option value='email'>{t('email邮箱信息')}</Option>
<Option value='api:read'>
{`api:read (${t('读取API')})`}
</Option>
<Option value='api:write'>
{`api:write (${t('写入API')})`}
</Option>
<Option value='admin'>{t('admin管理员权限')}</Option>
</Form.Select>
<Form.Switch
field='require_pkce'
label={t('强制PKCE验证')}
size='large'
extraText={t(
'PKCEProof Key for Code Exchange可提高授权码流程的安全性。',
)}
/>
</Card>
{/* 重定向URI */}
<RedirectUriCard
t={t}
isAuthCodeSelected={isAuthCodeSelected}
redirectUris={redirectUris}
onAdd={addRedirectUri}
onUpdate={updateRedirectUri}
onRemove={removeRedirectUri}
onFillTemplate={fillRedirectUriTemplate}
/>
</div>
)}
</Form>
</Spin>
{/* 客户端信息展示模态框 */}
<ClientInfoModal
visible={showClientInfo}
onClose={handleClientInfoClose}
clientId={clientInfo.clientId}
clientSecret={clientInfo.clientSecret}
/>
</SideSheet>
);
};
export default OAuth2ClientModal;

View File

@@ -0,0 +1,57 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Modal, Banner, Typography } from '@douyinfe/semi-ui';
import { useTranslation } from 'react-i18next';
const { Text } = Typography;
const SecretDisplayModal = ({ visible, onClose, secret }) => {
const { t } = useTranslation();
return (
<Modal
title={t('客户端密钥已重新生成')}
visible={visible}
onCancel={onClose}
onOk={onClose}
cancelText=''
okText={t('我已复制保存')}
width={650}
bodyStyle={{ padding: '20px 24px' }}
>
<Banner
type='success'
closeIcon={null}
description={t(
'新的客户端密钥如下,请立即复制保存。关闭此窗口后将无法再次查看。',
)}
className='mb-5 !rounded-lg'
/>
<div className='flex justify-center items-center'>
<Text code copyable>
{secret}
</Text>
</div>
</Modal>
);
};
export default SecretDisplayModal;

View File

@@ -0,0 +1,72 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect } from 'react';
import { Modal } from '@douyinfe/semi-ui';
import { API, showError } from '../../../../helpers';
import { useTranslation } from 'react-i18next';
import CodeViewer from '../../../common/ui/CodeViewer';
const ServerInfoModal = ({ visible, onClose }) => {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [serverInfo, setServerInfo] = useState(null);
const loadServerInfo = async () => {
setLoading(true);
try {
const res = await API.get('/api/oauth/server-info');
setServerInfo(res.data);
} catch (error) {
showError(t('获取服务器信息失败'));
} finally {
setLoading(false);
}
};
useEffect(() => {
if (visible) {
loadServerInfo();
}
}, [visible]);
return (
<Modal
title={t('OAuth2 服务器信息')}
visible={visible}
onCancel={onClose}
onOk={onClose}
cancelText=''
okText={t('关闭')}
width={650}
bodyStyle={{ padding: '20px 24px' }}
confirmLoading={loading}
>
<CodeViewer
content={
serverInfo ? JSON.stringify(serverInfo, null, 2) : t('加载中...')
}
title={t('OAuth2 服务器配置')}
language='json'
/>
</Modal>
);
};
export default ServerInfoModal;

View File

@@ -28,7 +28,6 @@ import {
Tabs,
TabPane,
Popover,
Modal,
} from '@douyinfe/semi-ui';
import {
IconMail,
@@ -84,9 +83,6 @@ const AccountManagement = ({
</Popover>
);
};
const isBound = (accountId) => Boolean(accountId);
const [showTelegramBindModal, setShowTelegramBindModal] = React.useState(false);
return (
<Card className='!rounded-2xl'>
{/* 卡片头部 */}
@@ -146,7 +142,7 @@ const AccountManagement = ({
size='small'
onClick={() => setShowEmailBindModal(true)}
>
{isBound(userState.user?.email)
{userState.user && userState.user.email !== ''
? t('修改绑定')
: t('绑定')}
</Button>
@@ -169,11 +165,9 @@ const AccountManagement = ({
{t('微信')}
</div>
<div className='text-sm text-gray-500 truncate'>
{!status.wechat_login
? t('未启用')
: isBound(userState.user?.wechat_id)
? t('已绑定')
: t('未绑定')}
{userState.user && userState.user.wechat_id !== ''
? t('已绑定')
: t('未绑定')}
</div>
</div>
</div>
@@ -185,7 +179,7 @@ const AccountManagement = ({
disabled={!status.wechat_login}
onClick={() => setShowWeChatBindModal(true)}
>
{isBound(userState.user?.wechat_id)
{userState.user && userState.user.wechat_id !== ''
? t('修改绑定')
: status.wechat_login
? t('绑定')
@@ -226,7 +220,8 @@ const AccountManagement = ({
onGitHubOAuthClicked(status.github_client_id)
}
disabled={
isBound(userState.user?.github_id) || !status.github_oauth
(userState.user && userState.user.github_id !== '') ||
!status.github_oauth
}
>
{status.github_oauth ? t('绑定') : t('未启用')}
@@ -269,7 +264,8 @@ const AccountManagement = ({
)
}
disabled={
isBound(userState.user?.oidc_id) || !status.oidc_enabled
(userState.user && userState.user.oidc_id !== '') ||
!status.oidc_enabled
}
>
{status.oidc_enabled ? t('绑定') : t('未启用')}
@@ -302,56 +298,26 @@ const AccountManagement = ({
</div>
<div className='flex-shrink-0'>
{status.telegram_oauth ? (
isBound(userState.user?.telegram_id) ? (
<Button
disabled
size='small'
type='primary'
theme='outline'
>
userState.user.telegram_id !== '' ? (
<Button disabled={true} size='small'>
{t('已绑定')}
</Button>
) : (
<Button
type='primary'
theme='outline'
size='small'
onClick={() => setShowTelegramBindModal(true)}
>
{t('绑定')}
</Button>
<div className='scale-75'>
<TelegramLoginButton
dataAuthUrl='/api/oauth/telegram/bind'
botName={status.telegram_bot_name}
/>
</div>
)
) : (
<Button
disabled
size='small'
type='primary'
theme='outline'
>
<Button disabled={true} size='small'>
{t('未启用')}
</Button>
)}
</div>
</div>
</Card>
<Modal
title={t('绑定 Telegram')}
visible={showTelegramBindModal}
onCancel={() => setShowTelegramBindModal(false)}
footer={null}
>
<div className='my-3 text-sm text-gray-600'>
{t('点击下方按钮通过 Telegram 完成绑定')}
</div>
<div className='flex justify-center'>
<div className='scale-90'>
<TelegramLoginButton
dataAuthUrl='/api/oauth/telegram/bind'
botName={status.telegram_bot_name}
/>
</div>
</div>
</Modal>
{/* LinuxDO绑定 */}
<Card className='!rounded-xl'>
@@ -384,7 +350,8 @@ const AccountManagement = ({
onLinuxDOOAuthClicked(status.linuxdo_client_id)
}
disabled={
isBound(userState.user?.linux_do_id) || !status.linuxdo_oauth
(userState.user && userState.user.linux_do_id !== '') ||
!status.linuxdo_oauth
}
>
{status.linuxdo_oauth ? t('绑定') : t('未启用')}

View File

@@ -40,11 +40,10 @@ import {
showSuccess,
showError,
} from '../../../../helpers';
import CodeViewer from '../../../playground/CodeViewer';
import CodeViewer from '../../../common/ui/CodeViewer';
import { StatusContext } from '../../../../context/Status';
import { UserContext } from '../../../../context/User';
import { useUserPermissions } from '../../../../hooks/common/useUserPermissions';
import { useSidebar } from '../../../../hooks/common/useSidebar';
const NotificationSettings = ({
t,
@@ -98,9 +97,6 @@ const NotificationSettings = ({
isSidebarModuleAllowed,
} = useUserPermissions();
// 使用useSidebar钩子获取刷新方法
const { refreshUserConfig } = useSidebar();
// 左侧边栏设置处理函数
const handleSectionChange = (sectionKey) => {
return (checked) => {
@@ -136,9 +132,6 @@ const NotificationSettings = ({
});
if (res.data.success) {
showSuccess(t('侧边栏设置保存成功'));
// 刷新useSidebar钩子中的用户配置实现实时更新
await refreshUserConfig();
} else {
showError(res.data.message);
}
@@ -341,7 +334,7 @@ const NotificationSettings = ({
loading={sidebarLoading}
className='!rounded-lg'
>
{t('保存设置')}
{t('保存边栏设置')}
</Button>
</>
) : (

View File

@@ -85,26 +85,6 @@ const REGION_EXAMPLE = {
'claude-3-5-sonnet-20240620': 'europe-west1',
};
// 支持并且已适配通过接口获取模型列表的渠道类型
const MODEL_FETCHABLE_TYPES = new Set([
1,
4,
14,
34,
17,
26,
24,
47,
25,
20,
23,
31,
35,
40,
42,
48,
]);
function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) {
@@ -164,8 +144,6 @@ const EditChannelModal = (props) => {
settings: '',
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type
vertex_key_type: 'json',
// 企业账户设置
is_enterprise_account: false,
};
const [batch, setBatch] = useState(false);
const [multiToSingle, setMultiToSingle] = useState(false);
@@ -191,7 +169,6 @@ const EditChannelModal = (props) => {
const [channelSearchValue, setChannelSearchValue] = useState('');
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
const [keyMode, setKeyMode] = useState('append'); // 密钥模式replace覆盖或 append追加
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
// 2FA验证查看密钥相关状态
const [twoFAState, setTwoFAState] = useState({
@@ -238,7 +215,7 @@ const EditChannelModal = (props) => {
pass_through_body_enabled: false,
system_prompt: '',
});
const showApiConfigCard = true; // 控制是否显示 API 配置卡片
const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示)
const getInitValues = () => ({ ...originInputs });
// 处理渠道额外设置的更新
@@ -345,10 +322,6 @@ const EditChannelModal = (props) => {
case 36:
localModels = ['suno_music', 'suno_lyrics'];
break;
case 45:
localModels = getChannelModels(value);
setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' }));
break;
default:
localModels = getChannelModels(value);
break;
@@ -440,27 +413,15 @@ const EditChannelModal = (props) => {
parsedSettings.azure_responses_version || '';
// 读取 Vertex 密钥格式
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
// 读取企业账户设置
data.is_enterprise_account = parsedSettings.openrouter_enterprise === true;
} catch (error) {
console.error('解析其他设置失败:', error);
data.azure_responses_version = '';
data.region = '';
data.vertex_key_type = 'json';
data.is_enterprise_account = false;
}
} else {
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
data.vertex_key_type = 'json';
data.is_enterprise_account = false;
}
if (
data.type === 45 &&
(!data.base_url ||
(typeof data.base_url === 'string' && data.base_url.trim() === ''))
) {
data.base_url = 'https://ark.cn-beijing.volces.com';
}
setInputs(data);
@@ -472,8 +433,6 @@ const EditChannelModal = (props) => {
} else {
setAutoBan(true);
}
// 同步企业账户状态
setIsEnterpriseAccount(data.is_enterprise_account || false);
setBasicModels(getChannelModels(data.type));
// 同步更新channelSettings状态显示
setChannelSettings({
@@ -733,8 +692,6 @@ const EditChannelModal = (props) => {
});
// 重置密钥模式状态
setKeyMode('append');
// 重置企业账户状态
setIsEnterpriseAccount(false);
// 清空表单中的key_mode字段
if (formApiRef.current) {
formApiRef.current.setValue('key_mode', undefined);
@@ -867,10 +824,6 @@ const EditChannelModal = (props) => {
showInfo(t('请至少选择一个模型!'));
return;
}
if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) {
showInfo(t('请输入API地址'));
return;
}
if (
localInputs.model_mapping &&
localInputs.model_mapping !== '' &&
@@ -900,21 +853,6 @@ const EditChannelModal = (props) => {
};
localInputs.setting = JSON.stringify(channelExtraSettings);
// 处理type === 20的企业账户设置
if (localInputs.type === 20) {
let settings = {};
if (localInputs.settings) {
try {
settings = JSON.parse(localInputs.settings);
} catch (error) {
console.error('解析settings失败:', error);
}
}
// 设置企业账户标识无论是true还是false都要传到后端
settings.openrouter_enterprise = localInputs.is_enterprise_account === true;
localInputs.settings = JSON.stringify(settings);
}
// 清理不需要发送到后端的字段
delete localInputs.force_format;
delete localInputs.thinking_to_content;
@@ -922,7 +860,6 @@ const EditChannelModal = (props) => {
delete localInputs.pass_through_body_enabled;
delete localInputs.system_prompt;
delete localInputs.system_prompt_override;
delete localInputs.is_enterprise_account;
// 顶层的 vertex_key_type 不应发送给后端
delete localInputs.vertex_key_type;
@@ -964,56 +901,6 @@ const EditChannelModal = (props) => {
}
};
// 密钥去重函数
const deduplicateKeys = () => {
const currentKey = formApiRef.current?.getValue('key') || inputs.key || '';
if (!currentKey.trim()) {
showInfo(t('请先输入密钥'));
return;
}
// 按行分割密钥
const keyLines = currentKey.split('\n');
const beforeCount = keyLines.length;
// 使用哈希表去重,保持原有顺序
const keySet = new Set();
const deduplicatedKeys = [];
keyLines.forEach((line) => {
const trimmedLine = line.trim();
if (trimmedLine && !keySet.has(trimmedLine)) {
keySet.add(trimmedLine);
deduplicatedKeys.push(trimmedLine);
}
});
const afterCount = deduplicatedKeys.length;
const deduplicatedKeyText = deduplicatedKeys.join('\n');
// 更新表单和状态
if (formApiRef.current) {
formApiRef.current.setValue('key', deduplicatedKeyText);
}
handleInputChange('key', deduplicatedKeyText);
// 显示去重结果
const message = t(
'去重完成:去重前 {{before}} 个密钥,去重后 {{after}} 个密钥',
{
before: beforeCount,
after: afterCount,
},
);
if (beforeCount === afterCount) {
showInfo(t('未发现重复密钥'));
} else {
showSuccess(message);
}
};
const addCustomModels = () => {
if (customModel.trim() === '') return;
const modelArray = customModel.split(',').map((model) => model.trim());
@@ -1109,41 +996,24 @@ const EditChannelModal = (props) => {
</Checkbox>
)}
{batch && (
<>
<Checkbox
disabled={isEdit}
checked={multiToSingle}
onChange={() => {
setMultiToSingle((prev) => {
const nextValue = !prev;
setInputs((prevInputs) => {
const newInputs = { ...prevInputs };
if (nextValue) {
newInputs.multi_key_mode = multiKeyMode;
} else {
delete newInputs.multi_key_mode;
}
return newInputs;
});
return nextValue;
});
}}
>
{t('密钥聚合模式')}
</Checkbox>
{inputs.type !== 41 && (
<Button
size='small'
type='tertiary'
theme='outline'
onClick={deduplicateKeys}
style={{ textDecoration: 'underline' }}
>
{t('密钥去重')}
</Button>
)}
</>
<Checkbox
disabled={isEdit}
checked={multiToSingle}
onChange={() => {
setMultiToSingle((prev) => !prev);
setInputs((prev) => {
const newInputs = { ...prev };
if (!multiToSingle) {
newInputs.multi_key_mode = multiKeyMode;
} else {
delete newInputs.multi_key_mode;
}
return newInputs;
});
}}
>
{t('密钥聚合模式')}
</Checkbox>
)}
</Space>
) : null;
@@ -1307,21 +1177,6 @@ const EditChannelModal = (props) => {
onChange={(value) => handleInputChange('type', value)}
/>
{inputs.type === 20 && (
<Form.Switch
field='is_enterprise_account'
label={t('是否为企业账户')}
checkedText={t('是')}
uncheckedText={t('否')}
onChange={(value) => {
setIsEnterpriseAccount(value);
handleInputChange('is_enterprise_account', value);
}}
extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')}
initValue={inputs.is_enterprise_account}
/>
)}
<Form.Input
field='name'
label={t('名称')}
@@ -1405,7 +1260,7 @@ const EditChannelModal = (props) => {
autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)}
extraText={
<div className='flex items-center gap-2 flex-wrap'>
<div className='flex items-center gap-2'>
{isEdit &&
isMultiKeyChannel &&
keyMode === 'append' && (
@@ -1941,30 +1796,6 @@ const EditChannelModal = (props) => {
/>
</div>
)}
{inputs.type === 45 && (
<div>
<Form.Select
field='base_url'
label={t('API地址')}
placeholder={t('请选择API地址')}
onChange={(value) =>
handleInputChange('base_url', value)
}
optionList={[
{
value: 'https://ark.cn-beijing.volces.com',
label: 'https://ark.cn-beijing.volces.com'
},
{
value: 'https://ark.ap-southeast.bytepluses.com',
label: 'https://ark.ap-southeast.bytepluses.com'
}
]}
defaultValue='https://ark.cn-beijing.volces.com'
/>
</div>
)}
</Card>
)}
@@ -2048,15 +1879,13 @@ const EditChannelModal = (props) => {
>
{t('填入所有模型')}
</Button>
{MODEL_FETCHABLE_TYPES.has(inputs.type) && (
<Button
size='small'
type='tertiary'
onClick={() => fetchUpstreamModelList('models')}
>
{t('获取模型列表')}
</Button>
)}
<Button
size='small'
type='tertiary'
onClick={() => fetchUpstreamModelList('models')}
>
{t('获取模型列表')}
</Button>
<Button
size='small'
type='warning'

View File

@@ -247,32 +247,6 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
}
};
// Delete a specific key
const handleDeleteKey = async (keyIndex) => {
const operationId = `delete_${keyIndex}`;
setOperationLoading((prev) => ({ ...prev, [operationId]: true }));
try {
const res = await API.post('/api/channel/multi_key/manage', {
channel_id: channel.id,
action: 'delete_key',
key_index: keyIndex,
});
if (res.data.success) {
showSuccess(t('密钥已删除'));
await loadKeyStatus(currentPage, pageSize); // Reload current page
onRefresh && onRefresh(); // Refresh parent component
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('删除密钥失败'));
} finally {
setOperationLoading((prev) => ({ ...prev, [operationId]: false }));
}
};
// Handle page change
const handlePageChange = (page) => {
setCurrentPage(page);
@@ -410,7 +384,7 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
title: t('操作'),
key: 'action',
fixed: 'right',
width: 150,
width: 100,
render: (_, record) => (
<Space>
{record.status === 1 ? (
@@ -432,21 +406,6 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
{t('启用')}
</Button>
)}
<Popconfirm
title={t('确定要删除此密钥吗?')}
content={t('此操作不可撤销,将永久删除该密钥')}
onConfirm={() => handleDeleteKey(record.index)}
okType={'danger'}
position={'topRight'}
>
<Button
type='danger'
size='small'
loading={operationLoading[`delete_${record.index}`]}
>
{t('删除')}
</Button>
</Popconfirm>
</Space>
),
},

Some files were not shown because too many files have changed in this diff Show More