mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-06 14:23:09 +00:00
Compare commits
16 Commits
v0.9.1.0
...
feature/ss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
380e1b7d56 | ||
|
|
63828349de | ||
|
|
5706f0ee9f | ||
|
|
e9e1dbff5e | ||
|
|
315eabc1e7 | ||
|
|
359dbc9d94 | ||
|
|
e157ea6ba2 | ||
|
|
dc3dba0665 | ||
|
|
81272da9ac | ||
|
|
926cad87b3 | ||
|
|
418ce449b7 | ||
|
|
4a02ab23ce | ||
|
|
984097c60b | ||
|
|
5550ec017e | ||
|
|
9e6752e0ee | ||
|
|
91a0eb7031 |
8
.github/workflows/linux-release.yml
vendored
8
.github/workflows/linux-release.yml
vendored
@@ -38,21 +38,21 @@ jobs:
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
|
||||
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o new-api-arm64
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
new-api
|
||||
new-api-arm64
|
||||
one-api
|
||||
one-api-arm64
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
|
||||
4
.github/workflows/macos-release.yml
vendored
4
.github/workflows/macos-release.yml
vendored
@@ -39,12 +39,12 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o new-api-macos
|
||||
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api-macos
|
||||
files: one-api-macos
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
|
||||
4
.github/workflows/windows-release.yml
vendored
4
.github/workflows/windows-release.yml
vendored
@@ -41,12 +41,12 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o new-api.exe
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: new-api.exe
|
||||
files: one-api.exe
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
|
||||
22
common/ip.go
22
common/ip.go
@@ -1,22 +0,0 @@
|
||||
package common
|
||||
|
||||
import "net"
|
||||
|
||||
func IsPrivateIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
private := []net.IPNet{
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
|
||||
}
|
||||
|
||||
for _, privateNet := range private {
|
||||
if privateNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SSRFProtection SSRF防护配置
|
||||
type SSRFProtection struct {
|
||||
AllowPrivateIp bool
|
||||
DomainFilterMode bool // true: 白名单, false: 黑名单
|
||||
DomainList []string // domain format, e.g. example.com, *.example.com
|
||||
IpFilterMode bool // true: 白名单, false: 黑名单
|
||||
IpList []string // CIDR or single IP
|
||||
AllowedPorts []int // 允许的端口范围
|
||||
ApplyIPFilterForDomain bool // 对域名启用IP过滤
|
||||
}
|
||||
|
||||
// DefaultSSRFProtection 默认SSRF防护配置
|
||||
var DefaultSSRFProtection = &SSRFProtection{
|
||||
AllowPrivateIp: false,
|
||||
DomainFilterMode: true,
|
||||
DomainList: []string{},
|
||||
IpFilterMode: true,
|
||||
IpList: []string{},
|
||||
AllowedPorts: []int{},
|
||||
}
|
||||
|
||||
// isPrivateIP 检查IP是否为私有地址
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查私有网段
|
||||
private := []net.IPNet{
|
||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
||||
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
||||
{IP: net.IPv4(169, 254, 0, 0), Mask: net.CIDRMask(16, 32)}, // 169.254.0.0/16 (链路本地)
|
||||
{IP: net.IPv4(224, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 224.0.0.0/4 (组播)
|
||||
{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)}, // 240.0.0.0/4 (保留)
|
||||
}
|
||||
|
||||
for _, privateNet := range private {
|
||||
if privateNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查IPv6私有地址
|
||||
if ip.To4() == nil {
|
||||
// IPv6 loopback
|
||||
if ip.Equal(net.IPv6loopback) {
|
||||
return true
|
||||
}
|
||||
// IPv6 link-local
|
||||
if strings.HasPrefix(ip.String(), "fe80:") {
|
||||
return true
|
||||
}
|
||||
// IPv6 unique local
|
||||
if strings.HasPrefix(ip.String(), "fc") || strings.HasPrefix(ip.String(), "fd") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// parsePortRanges 解析端口范围配置
|
||||
// 支持格式: "80", "443", "8000-9000"
|
||||
func parsePortRanges(portConfigs []string) ([]int, error) {
|
||||
var ports []int
|
||||
|
||||
for _, config := range portConfigs {
|
||||
config = strings.TrimSpace(config)
|
||||
if config == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(config, "-") {
|
||||
// 处理端口范围 "8000-9000"
|
||||
parts := strings.Split(config, "-")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid port range format: %s", config)
|
||||
}
|
||||
|
||||
startPort, err := strconv.Atoi(strings.TrimSpace(parts[0]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start port in range %s: %v", config, err)
|
||||
}
|
||||
|
||||
endPort, err := strconv.Atoi(strings.TrimSpace(parts[1]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end port in range %s: %v", config, err)
|
||||
}
|
||||
|
||||
if startPort > endPort {
|
||||
return nil, fmt.Errorf("invalid port range %s: start port cannot be greater than end port", config)
|
||||
}
|
||||
|
||||
if startPort < 1 || startPort > 65535 || endPort < 1 || endPort > 65535 {
|
||||
return nil, fmt.Errorf("port range %s contains invalid port numbers (must be 1-65535)", config)
|
||||
}
|
||||
|
||||
// 添加范围内的所有端口
|
||||
for port := startPort; port <= endPort; port++ {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
} else {
|
||||
// 处理单个端口 "80"
|
||||
port, err := strconv.Atoi(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port number: %s", config)
|
||||
}
|
||||
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("invalid port number %d (must be 1-65535)", port)
|
||||
}
|
||||
|
||||
ports = append(ports, port)
|
||||
}
|
||||
}
|
||||
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
// isAllowedPort 检查端口是否被允许
|
||||
func (p *SSRFProtection) isAllowedPort(port int) bool {
|
||||
if len(p.AllowedPorts) == 0 {
|
||||
return true // 如果没有配置端口限制,则允许所有端口
|
||||
}
|
||||
|
||||
for _, allowedPort := range p.AllowedPorts {
|
||||
if port == allowedPort {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isDomainWhitelisted 检查域名是否在白名单中
|
||||
func isDomainListed(domain string, list []string) bool {
|
||||
if len(list) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
for _, item := range list {
|
||||
item = strings.ToLower(strings.TrimSpace(item))
|
||||
if item == "" {
|
||||
continue
|
||||
}
|
||||
// 精确匹配
|
||||
if domain == item {
|
||||
return true
|
||||
}
|
||||
// 通配符匹配 (*.example.com)
|
||||
if strings.HasPrefix(item, "*.") {
|
||||
suffix := strings.TrimPrefix(item, "*.")
|
||||
if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *SSRFProtection) isDomainAllowed(domain string) bool {
|
||||
listed := isDomainListed(domain, p.DomainList)
|
||||
if p.DomainFilterMode { // 白名单
|
||||
return listed
|
||||
}
|
||||
// 黑名单
|
||||
return !listed
|
||||
}
|
||||
|
||||
// isIPWhitelisted 检查IP是否在白名单中
|
||||
|
||||
func isIPListed(ip net.IP, list []string) bool {
|
||||
if len(list) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, whitelistCIDR := range list {
|
||||
_, network, err := net.ParseCIDR(whitelistCIDR)
|
||||
if err != nil {
|
||||
// 尝试作为单个IP处理
|
||||
if whitelistIP := net.ParseIP(whitelistCIDR); whitelistIP != nil {
|
||||
if ip.Equal(whitelistIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsIPAccessAllowed 检查IP是否允许访问
|
||||
func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
|
||||
// 私有IP限制
|
||||
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||
return false
|
||||
}
|
||||
|
||||
listed := isIPListed(ip, p.IpList)
|
||||
if p.IpFilterMode { // 白名单
|
||||
return listed
|
||||
}
|
||||
// 黑名单
|
||||
return !listed
|
||||
}
|
||||
|
||||
// ValidateURL 验证URL是否安全
|
||||
func (p *SSRFProtection) ValidateURL(urlStr string) error {
|
||||
// 解析URL
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %v", err)
|
||||
}
|
||||
|
||||
// 只允许HTTP/HTTPS协议
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
|
||||
}
|
||||
|
||||
// 解析主机和端口
|
||||
host, portStr, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
// 没有端口,使用默认端口
|
||||
host = u.Hostname()
|
||||
if u.Scheme == "https" {
|
||||
portStr = "443"
|
||||
} else {
|
||||
portStr = "80"
|
||||
}
|
||||
}
|
||||
|
||||
// 验证端口
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
if !p.isAllowedPort(port) {
|
||||
return fmt.Errorf("port %d is not allowed", port)
|
||||
}
|
||||
|
||||
// 如果 host 是 IP,则跳过域名检查
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if !p.IsIPAccessAllowed(ip) {
|
||||
if isPrivateIP(ip) {
|
||||
return fmt.Errorf("private IP address not allowed: %s", ip.String())
|
||||
}
|
||||
if p.IpFilterMode {
|
||||
return fmt.Errorf("ip not in whitelist: %s", ip.String())
|
||||
}
|
||||
return fmt.Errorf("ip in blacklist: %s", ip.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先进行域名过滤
|
||||
if !p.isDomainAllowed(host) {
|
||||
if p.DomainFilterMode {
|
||||
return fmt.Errorf("domain not in whitelist: %s", host)
|
||||
}
|
||||
return fmt.Errorf("domain in blacklist: %s", host)
|
||||
}
|
||||
|
||||
// 若未启用对域名应用IP过滤,则到此通过
|
||||
if !p.ApplyIPFilterForDomain {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析域名对应IP并检查
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if !p.IsIPAccessAllowed(ip) {
|
||||
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
|
||||
}
|
||||
if p.IpFilterMode {
|
||||
return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
|
||||
}
|
||||
return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
|
||||
func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error {
|
||||
// 如果SSRF防护被禁用,直接返回成功
|
||||
if !enableSSRFProtection {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析端口范围配置
|
||||
allowedPortInts, err := parsePortRanges(allowedPorts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request reject - invalid port configuration: %v", err)
|
||||
}
|
||||
|
||||
protection := &SSRFProtection{
|
||||
AllowPrivateIp: allowPrivateIp,
|
||||
DomainFilterMode: domainFilterMode,
|
||||
DomainList: domainList,
|
||||
IpFilterMode: ipFilterMode,
|
||||
IpList: ipList,
|
||||
AllowedPorts: allowedPortInts,
|
||||
ApplyIPFilterForDomain: applyIPFilterForDomain,
|
||||
}
|
||||
return protection.ValidateURL(urlStr)
|
||||
}
|
||||
@@ -2,10 +2,9 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SysLog(s string) {
|
||||
@@ -23,33 +22,3 @@ func FatalLog(v ...any) {
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func LogStartupSuccess(startTime time.Time, port string) {
|
||||
|
||||
duration := time.Since(startTime)
|
||||
durationMs := duration.Milliseconds()
|
||||
|
||||
// Get network IPs
|
||||
networkIps := GetNetworkIps()
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
|
||||
// Print the main success message
|
||||
fmt.Fprintf(gin.DefaultWriter, " \033[32m%s %s\033[0m ready in %d ms\n", SystemName, Version, durationMs)
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
|
||||
// Skip fancy startup message in container environments
|
||||
if !IsRunningInContainer() {
|
||||
// Print local URL
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mLocal:\033[0m http://localhost:%s/\n", port)
|
||||
}
|
||||
|
||||
// Print network URLs
|
||||
for _, ip := range networkIps {
|
||||
fmt.Fprintf(gin.DefaultWriter, " ➜ \033[1mNetwork:\033[0m http://%s:%s/\n", ip, port)
|
||||
}
|
||||
|
||||
// Print blank line for spacing
|
||||
fmt.Fprintf(gin.DefaultWriter, "\n")
|
||||
}
|
||||
|
||||
@@ -68,78 +68,6 @@ func GetIp() (ip string) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetNetworkIps() []string {
|
||||
var networkIps []string
|
||||
ips, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return networkIps
|
||||
}
|
||||
|
||||
for _, a := range ips {
|
||||
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||
if ipNet.IP.To4() != nil {
|
||||
ip := ipNet.IP.String()
|
||||
// Include common private network ranges
|
||||
if strings.HasPrefix(ip, "10.") ||
|
||||
strings.HasPrefix(ip, "172.") ||
|
||||
strings.HasPrefix(ip, "192.168.") {
|
||||
networkIps = append(networkIps, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return networkIps
|
||||
}
|
||||
|
||||
// IsRunningInContainer detects if the application is running inside a container
|
||||
func IsRunningInContainer() bool {
|
||||
// Method 1: Check for .dockerenv file (Docker containers)
|
||||
if _, err := os.Stat("/.dockerenv"); err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Method 2: Check cgroup for container indicators
|
||||
if data, err := os.ReadFile("/proc/1/cgroup"); err == nil {
|
||||
content := string(data)
|
||||
if strings.Contains(content, "docker") ||
|
||||
strings.Contains(content, "containerd") ||
|
||||
strings.Contains(content, "kubepods") ||
|
||||
strings.Contains(content, "/lxc/") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Method 3: Check environment variables commonly set by container runtimes
|
||||
containerEnvVars := []string{
|
||||
"KUBERNETES_SERVICE_HOST",
|
||||
"DOCKER_CONTAINER",
|
||||
"container",
|
||||
}
|
||||
|
||||
for _, envVar := range containerEnvVars {
|
||||
if os.Getenv(envVar) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Method 4: Check if init process is not the traditional init
|
||||
if data, err := os.ReadFile("/proc/1/comm"); err == nil {
|
||||
comm := strings.TrimSpace(string(data))
|
||||
// In containers, process 1 is often not "init" or "systemd"
|
||||
if comm != "init" && comm != "systemd" {
|
||||
// Additional check: if it's a common container entrypoint
|
||||
if strings.Contains(comm, "docker") ||
|
||||
strings.Contains(comm, "containerd") ||
|
||||
strings.Contains(comm, "runc") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var sizeKB = 1024
|
||||
var sizeMB = sizeKB * 1024
|
||||
var sizeGB = sizeMB * 1024
|
||||
|
||||
@@ -11,10 +11,8 @@ const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
TaskActionFirstTailGenerate = "firstTailGenerate"
|
||||
TaskActionReferenceGenerate = "referenceGenerate"
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -90,11 +90,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
// VolcEngine 图像生成模型
|
||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||
requestPath = "/v1/images/generations"
|
||||
}
|
||||
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||||
@@ -114,21 +109,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
}
|
||||
}
|
||||
|
||||
// 重新检查模型类型并更新请求路径
|
||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||
strings.HasPrefix(testModel, "m3e") ||
|
||||
strings.Contains(testModel, "bge-") ||
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == constant.ChannelTypeMokaAI {
|
||||
requestPath = "/v1/embeddings"
|
||||
c.Request.URL.Path = requestPath
|
||||
}
|
||||
|
||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||
requestPath = "/v1/images/generations"
|
||||
c.Request.URL.Path = requestPath
|
||||
}
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
@@ -160,9 +140,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
if c.Request.URL.Path == "/v1/embeddings" {
|
||||
relayFormat = types.RelayFormatEmbedding
|
||||
}
|
||||
if c.Request.URL.Path == "/v1/images/generations" {
|
||||
relayFormat = types.RelayFormatOpenAIImage
|
||||
}
|
||||
|
||||
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||
|
||||
@@ -224,22 +201,6 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
}
|
||||
// 调用专门用于 Embedding 的转换函数
|
||||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||
} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
|
||||
// 创建一个 ImageRequest
|
||||
prompt := "cat"
|
||||
if request.Prompt != nil {
|
||||
if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
|
||||
prompt = promptStr
|
||||
}
|
||||
}
|
||||
imageRequest := dto.ImageRequest{
|
||||
Prompt: prompt,
|
||||
Model: request.Model,
|
||||
N: uint(request.N),
|
||||
Size: request.Size,
|
||||
}
|
||||
// 调用专门用于图像生成的转换函数
|
||||
convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
|
||||
} else {
|
||||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -189,8 +188,6 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
@@ -504,10 +501,9 @@ func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
}
|
||||
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
}
|
||||
|
||||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||
@@ -620,13 +616,6 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
localChannel := addChannelRequest.Channel
|
||||
localChannel.Key = key
|
||||
if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
|
||||
keyPrefix := localChannel.Key
|
||||
if len(localChannel.Key) > 8 {
|
||||
keyPrefix = localChannel.Key[:8]
|
||||
}
|
||||
localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
|
||||
}
|
||||
channels = append(channels, *localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
@@ -634,7 +623,6 @@ func AddChannel(c *gin.Context) {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
service.ResetProxyClientCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -896,7 +884,6 @@ func UpdateChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
service.ResetProxyClientCache()
|
||||
channel.Key = ""
|
||||
clearChannelInfo(&channel.Channel)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1106,8 +1093,8 @@ func CopyChannel(c *gin.Context) {
|
||||
// MultiKeyManageRequest represents the request for multi-key management operations
|
||||
type MultiKeyManageRequest struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
|
||||
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
|
||||
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
|
||||
Page int `json:"page,omitempty"` // for get_key_status pagination
|
||||
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
|
||||
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
|
||||
@@ -1435,86 +1422,6 @@ func ManageMultiKeys(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_key":
|
||||
if request.KeyIndex == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "未指定要删除的密钥索引",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keyIndex := *request.KeyIndex
|
||||
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "密钥索引超出范围",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
var newStatusList = make(map[int]int)
|
||||
var newDisabledTime = make(map[int]int64)
|
||||
var newDisabledReason = make(map[int]string)
|
||||
|
||||
newIndex := 0
|
||||
for i, key := range keys {
|
||||
// 跳过要删除的密钥
|
||||
if i == keyIndex {
|
||||
continue
|
||||
}
|
||||
|
||||
remainingKeys = append(remainingKeys, key)
|
||||
|
||||
// 保留其他密钥的状态信息,重新索引
|
||||
if channel.ChannelInfo.MultiKeyStatusList != nil {
|
||||
if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
|
||||
newStatusList[newIndex] = status
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
|
||||
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
|
||||
newDisabledTime[newIndex] = t
|
||||
}
|
||||
}
|
||||
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
|
||||
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
|
||||
newDisabledReason[newIndex] = r
|
||||
}
|
||||
}
|
||||
newIndex++
|
||||
}
|
||||
|
||||
if len(remainingKeys) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不能删除最后一个密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update channel with remaining keys
|
||||
channel.Key = strings.Join(remainingKeys, "\n")
|
||||
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
|
||||
channel.ChannelInfo.MultiKeyStatusList = newStatusList
|
||||
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
|
||||
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
|
||||
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "密钥已删除",
|
||||
})
|
||||
return
|
||||
|
||||
case "delete_disabled_keys":
|
||||
keys := channel.GetKeys()
|
||||
var remainingKeys []string
|
||||
|
||||
375
controller/oauth.go
Normal file
375
controller/oauth.go
Normal file
@@ -0,0 +1,375 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/src/oauth"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
"one-api/middleware"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetJWKS 获取JWKS公钥集
|
||||
func GetJWKS(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// lazy init if needed
|
||||
_ = oauth.EnsureInitialized()
|
||||
|
||||
jwks := oauth.GetJWKS()
|
||||
if jwks == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "JWKS not available",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置CORS headers
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type")
|
||||
c.Header("Cache-Control", "public, max-age=3600") // 缓存1小时
|
||||
|
||||
// 返回JWKS
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// 将JWKS转换为JSON字符串
|
||||
jsonData, err := json.Marshal(jwks)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Failed to marshal JWKS",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.String(http.StatusOK, string(jsonData))
|
||||
}
|
||||
|
||||
// OAuthTokenEndpoint OAuth2 令牌端点
|
||||
func OAuthTokenEndpoint(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "unsupported_grant_type",
|
||||
"error_description": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许POST请求
|
||||
if c.Request.Method != "POST" {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST method is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许application/x-www-form-urlencoded内容类型
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if contentType == "" || !strings.Contains(strings.ToLower(contentType), "application/x-www-form-urlencoded") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Content-Type must be application/x-www-form-urlencoded",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// lazy init
|
||||
if err := oauth.EnsureInitialized(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
|
||||
return
|
||||
}
|
||||
oauth.HandleTokenRequest(c)
|
||||
}
|
||||
|
||||
// OAuthAuthorizeEndpoint OAuth2 授权端点
|
||||
func OAuthAuthorizeEndpoint(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "server_error",
|
||||
"error_description": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := oauth.EnsureInitialized(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error", "error_description": err.Error()})
|
||||
return
|
||||
}
|
||||
oauth.HandleAuthorizeRequest(c)
|
||||
}
|
||||
|
||||
// OAuthServerInfo 获取OAuth2服务器信息
|
||||
func OAuthServerInfo(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回OAuth2服务器的基本信息(类似OpenID Connect Discovery)
|
||||
issuer := settings.Issuer
|
||||
if issuer == "" {
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
|
||||
scheme = hdr
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
issuer = scheme + "://" + c.Request.Host
|
||||
}
|
||||
|
||||
base := issuer + "/api"
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": base + "/oauth/authorize",
|
||||
"token_endpoint": base + "/oauth/token",
|
||||
"jwks_uri": base + "/.well-known/jwks.json",
|
||||
"grant_types_supported": settings.AllowedGrantTypes,
|
||||
"response_types_supported": []string{"code", "token"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
|
||||
"default_private_key_path": settings.DefaultPrivateKeyPath,
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthOIDCConfiguration OIDC discovery document
|
||||
func OAuthOIDCConfiguration(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
|
||||
return
|
||||
}
|
||||
issuer := settings.Issuer
|
||||
if issuer == "" {
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
if hdr := c.Request.Header.Get("X-Forwarded-Proto"); hdr != "" {
|
||||
scheme = hdr
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
issuer = scheme + "://" + c.Request.Host
|
||||
}
|
||||
base := issuer + "/api"
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": base + "/oauth/authorize",
|
||||
"token_endpoint": base + "/oauth/token",
|
||||
"userinfo_endpoint": base + "/oauth/userinfo",
|
||||
"jwks_uri": base + "/.well-known/jwks.json",
|
||||
"response_types_supported": []string{"code", "token"},
|
||||
"grant_types_supported": settings.AllowedGrantTypes,
|
||||
"subject_types_supported": []string{"public"},
|
||||
"id_token_signing_alg_values_supported": []string{"RS256"},
|
||||
"scopes_supported": []string{"openid", "profile", "email", "api:read", "api:write", "admin"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic", "client_secret_post"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"default_private_key_path": settings.DefaultPrivateKeyPath,
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthIntrospect 令牌内省端点(RFC 7662)
|
||||
func OAuthIntrospect(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许POST请求
|
||||
if c.Request.Method != "POST" {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST method is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token := c.PostForm("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"active": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := token
|
||||
|
||||
// 验证并解析JWT
|
||||
parsed, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
pub := oauth.GetPublicKeyByKid(func() string {
|
||||
if v, ok := token.Header["kid"].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
if pub == nil {
|
||||
return nil, jwt.ErrTokenUnverifiable
|
||||
}
|
||||
return pub, nil
|
||||
})
|
||||
if err != nil || !parsed.Valid {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查撤销
|
||||
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
||||
if revoked, _ := model.IsTokenRevoked(jti); revoked {
|
||||
c.JSON(http.StatusOK, gin.H{"active": false})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 有效
|
||||
resp := gin.H{"active": true}
|
||||
for k, v := range claims {
|
||||
resp[k] = v
|
||||
}
|
||||
resp["token_type"] = "Bearer"
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// OAuthRevoke 令牌撤销端点(RFC 7009)
|
||||
func OAuthRevoke(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "OAuth2 server is disabled",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只允许POST请求
|
||||
if c.Request.Method != "POST" {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Only POST method is allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token := c.PostForm("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing token parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token = c.PostForm("token")
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request",
|
||||
"error_description": "Missing token parameter",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析JWT,若成功则记录jti到撤销表
|
||||
parsed, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
pub := oauth.GetRSAPublicKey()
|
||||
if pub == nil {
|
||||
return nil, jwt.ErrTokenUnverifiable
|
||||
}
|
||||
return pub, nil
|
||||
})
|
||||
if err == nil && parsed != nil && parsed.Valid {
|
||||
if claims, ok := parsed.Claims.(jwt.MapClaims); ok {
|
||||
var jti string
|
||||
var exp int64
|
||||
if v, ok := claims["jti"].(string); ok {
|
||||
jti = v
|
||||
}
|
||||
if v, ok := claims["exp"].(float64); ok {
|
||||
exp = int64(v)
|
||||
} else if v, ok := claims["exp"].(int64); ok {
|
||||
exp = v
|
||||
}
|
||||
if jti != "" {
|
||||
// 如果没有exp,默认撤销至当前+TTL 10分钟
|
||||
if exp == 0 {
|
||||
exp = time.Now().Add(10 * time.Minute).Unix()
|
||||
}
|
||||
_ = model.RevokeToken(jti, exp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// OAuthUserInfo returns OIDC userinfo based on access token
|
||||
func OAuthUserInfo(c *gin.Context) {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "OAuth2 server is disabled"})
|
||||
return
|
||||
}
|
||||
// 需要 OAuthJWTAuth 中间件注入 claims
|
||||
claims, ok := middleware.GetOAuthClaims(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
|
||||
return
|
||||
}
|
||||
// scope 校验:必须包含 openid
|
||||
scope, _ := claims["scope"].(string)
|
||||
if !strings.Contains(" "+scope+" ", " openid ") {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "insufficient_scope"})
|
||||
return
|
||||
}
|
||||
sub, _ := claims["sub"].(string)
|
||||
resp := gin.H{"sub": sub}
|
||||
// 若包含 profile/email scope,补充返回
|
||||
if strings.Contains(" "+scope+" ", " profile ") || strings.Contains(" "+scope+" ", " email ") {
|
||||
if uid, err := strconv.Atoi(sub); err == nil {
|
||||
if user, err2 := model.GetUserById(uid, false); err2 == nil && user != nil {
|
||||
if strings.Contains(" "+scope+" ", " profile ") {
|
||||
resp["name"] = user.DisplayName
|
||||
resp["preferred_username"] = user.Username
|
||||
}
|
||||
if strings.Contains(" "+scope+" ", " email ") {
|
||||
resp["email"] = user.Email
|
||||
resp["email_verified"] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
374
controller/oauth_client.go
Normal file
374
controller/oauth_client.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
// CreateOAuthClientRequest 创建OAuth客户端请求
|
||||
type CreateOAuthClientRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
ClientType string `json:"client_type" binding:"required,oneof=confidential public"`
|
||||
GrantTypes []string `json:"grant_types" binding:"required"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
Scopes []string `json:"scopes" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
RequirePKCE bool `json:"require_pkce"`
|
||||
}
|
||||
|
||||
// UpdateOAuthClientRequest 更新OAuth客户端请求
|
||||
type UpdateOAuthClientRequest struct {
|
||||
ID string `json:"id" binding:"required"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
ClientType string `json:"client_type" binding:"required,oneof=confidential public"`
|
||||
GrantTypes []string `json:"grant_types" binding:"required"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
Scopes []string `json:"scopes" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
RequirePKCE bool `json:"require_pkce"`
|
||||
Status int `json:"status" binding:"required,oneof=1 2"`
|
||||
}
|
||||
|
||||
// GetAllOAuthClients 获取所有OAuth客户端
|
||||
func GetAllOAuthClients(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.Query("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
perPage, _ := strconv.Atoi(c.Query("per_page"))
|
||||
if perPage < 1 || perPage > 100 {
|
||||
perPage = 20
|
||||
}
|
||||
|
||||
startIdx := (page - 1) * perPage
|
||||
clients, err := model.GetAllOAuthClients(startIdx, perPage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
for _, client := range clients {
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
}
|
||||
|
||||
total, _ := model.CountOAuthClients()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": clients,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": perPage,
|
||||
})
|
||||
}
|
||||
|
||||
// SearchOAuthClients 搜索OAuth客户端
|
||||
func SearchOAuthClients(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
if keyword == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "关键词不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
clients, err := model.SearchOAuthClients(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
for _, client := range clients {
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": clients,
|
||||
})
|
||||
}
|
||||
|
||||
// GetOAuthClient 获取单个OAuth客户端
|
||||
func GetOAuthClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client, err := model.GetOAuthClientByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "客户端不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": client,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateOAuthClient 创建OAuth客户端
|
||||
func CreateOAuthClient(c *gin.Context) {
|
||||
var req CreateOAuthClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
validGrantTypes := []string{"client_credentials", "authorization_code", "refresh_token"}
|
||||
for _, grantType := range req.GrantTypes {
|
||||
if !contains(validGrantTypes, grantType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的授权类型: " + grantType,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果包含authorization_code,则必须提供redirect_uris
|
||||
if contains(req.GrantTypes, "authorization_code") && len(req.RedirectURIs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "授权码模式需要提供重定向URI",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成客户端ID和密钥
|
||||
clientID := generateClientID()
|
||||
clientSecret := ""
|
||||
if req.ClientType == "confidential" {
|
||||
clientSecret = generateClientSecret()
|
||||
}
|
||||
|
||||
// 获取创建者ID
|
||||
createdBy := c.GetInt("id")
|
||||
|
||||
// 创建客户端
|
||||
client := &model.OAuthClient{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
Name: req.Name,
|
||||
ClientType: req.ClientType,
|
||||
RequirePKCE: req.RequirePKCE,
|
||||
Status: common.UserStatusEnabled,
|
||||
CreatedBy: createdBy,
|
||||
Description: req.Description,
|
||||
}
|
||||
|
||||
client.SetGrantTypes(req.GrantTypes)
|
||||
client.SetRedirectURIs(req.RedirectURIs)
|
||||
client.SetScopes(req.Scopes)
|
||||
|
||||
err := model.CreateOAuthClient(client)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "创建客户端失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回结果(包含完整的客户端密钥,仅此一次)
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端创建成功",
|
||||
"client_id": client.ID,
|
||||
"client_secret": client.Secret, // 仅在创建时返回完整密钥
|
||||
"data": client,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateOAuthClient 更新OAuth客户端
|
||||
func UpdateOAuthClient(c *gin.Context) {
|
||||
var req UpdateOAuthClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取现有客户端
|
||||
client, err := model.GetOAuthClientByID(req.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "客户端不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
validGrantTypes := []string{"client_credentials", "authorization_code", "refresh_token"}
|
||||
for _, grantType := range req.GrantTypes {
|
||||
if !contains(validGrantTypes, grantType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的授权类型: " + grantType,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新客户端信息
|
||||
client.Name = req.Name
|
||||
client.ClientType = req.ClientType
|
||||
client.RequirePKCE = req.RequirePKCE
|
||||
client.Status = req.Status
|
||||
client.Description = req.Description
|
||||
client.SetGrantTypes(req.GrantTypes)
|
||||
client.SetRedirectURIs(req.RedirectURIs)
|
||||
client.SetScopes(req.Scopes)
|
||||
|
||||
err = model.UpdateOAuthClient(client)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "更新客户端失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 清理敏感信息
|
||||
client.Secret = maskSecret(client.Secret)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端更新成功",
|
||||
"data": client,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteOAuthClient 删除OAuth客户端
|
||||
func DeleteOAuthClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err := model.DeleteOAuthClient(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "删除客户端失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端删除成功",
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateOAuthClientSecret 重新生成客户端密钥
|
||||
func RegenerateOAuthClientSecret(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client, err := model.GetOAuthClientByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "客户端不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 只有机密客户端才能重新生成密钥
|
||||
if client.ClientType != "confidential" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "只有机密客户端才能重新生成密钥",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新密钥
|
||||
client.Secret = generateClientSecret()
|
||||
|
||||
err = model.UpdateOAuthClient(client)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "重新生成密钥失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "客户端密钥重新生成成功",
|
||||
"client_secret": client.Secret, // 返回新生成的密钥
|
||||
})
|
||||
}
|
||||
|
||||
// generateClientID 生成客户端ID
|
||||
func generateClientID() string {
|
||||
return "client_" + randstr.String(16)
|
||||
}
|
||||
|
||||
// generateClientSecret 生成客户端密钥
|
||||
func generateClientSecret() string {
|
||||
return randstr.String(32)
|
||||
}
|
||||
|
||||
// maskSecret 掩码密钥显示
|
||||
func maskSecret(secret string) string {
|
||||
if len(secret) <= 6 {
|
||||
return strings.Repeat("*", len(secret))
|
||||
}
|
||||
return secret[:3] + strings.Repeat("*", len(secret)-6) + secret[len(secret)-3:]
|
||||
}
|
||||
|
||||
// contains 检查字符串切片是否包含指定值
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
89
controller/oauth_keys.go
Normal file
89
controller/oauth_keys.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/logger"
|
||||
"one-api/src/oauth"
|
||||
)
|
||||
|
||||
type rotateKeyRequest struct {
|
||||
Kid string `json:"kid"`
|
||||
}
|
||||
|
||||
type genKeyFileRequest struct {
|
||||
Path string `json:"path"`
|
||||
Kid string `json:"kid"`
|
||||
Overwrite bool `json:"overwrite"`
|
||||
}
|
||||
|
||||
type importPemRequest struct {
|
||||
Pem string `json:"pem"`
|
||||
Kid string `json:"kid"`
|
||||
}
|
||||
|
||||
// RotateOAuthSigningKey rotates the OAuth2 JWT signing key (Root only)
|
||||
func RotateOAuthSigningKey(c *gin.Context) {
|
||||
var req rotateKeyRequest
|
||||
_ = c.BindJSON(&req)
|
||||
kid, err := oauth.RotateSigningKey(req.Kid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key rotated: "+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
|
||||
}
|
||||
|
||||
// ListOAuthSigningKeys returns current and historical JWKS signing keys
|
||||
func ListOAuthSigningKeys(c *gin.Context) {
|
||||
keys := oauth.ListSigningKeys()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": keys})
|
||||
}
|
||||
|
||||
// DeleteOAuthSigningKey deletes a non-current key by kid
|
||||
func DeleteOAuthSigningKey(c *gin.Context) {
|
||||
kid := c.Param("kid")
|
||||
if kid == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "kid required"})
|
||||
return
|
||||
}
|
||||
if err := oauth.DeleteSigningKey(kid); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key deleted: "+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GenerateOAuthSigningKeyFile generates a private key file and rotates current kid
|
||||
func GenerateOAuthSigningKeyFile(c *gin.Context) {
|
||||
var req genKeyFileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.Path == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "path required"})
|
||||
return
|
||||
}
|
||||
kid, err := oauth.GenerateAndPersistKey(req.Path, req.Kid, req.Overwrite)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key generated to file: "+req.Path+" kid="+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid, "path": req.Path})
|
||||
}
|
||||
|
||||
// ImportOAuthSigningKey imports PEM text and rotates current kid
|
||||
func ImportOAuthSigningKey(c *gin.Context) {
|
||||
var req importPemRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.Pem == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "pem required"})
|
||||
return
|
||||
}
|
||||
kid, err := oauth.ImportPEMKey(req.Pem, req.Kid)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, "oauth signing key imported from PEM, kid="+kid)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "kid": kid})
|
||||
}
|
||||
@@ -128,33 +128,6 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ImageRatio":
|
||||
err = ratio_setting.UpdateImageRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "图片倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "AudioRatio":
|
||||
err = ratio_setting.UpdateAudioRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "音频倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "AudioCompletionRatio":
|
||||
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "音频补全倍率设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||
if err != nil {
|
||||
|
||||
@@ -53,7 +53,7 @@ func GetSetup(c *gin.Context) {
|
||||
func PostSetup(c *gin.Context) {
|
||||
// Check if setup is already completed
|
||||
if constant.Setup {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "系统已经初始化完成",
|
||||
})
|
||||
@@ -66,7 +66,7 @@ func PostSetup(c *gin.Context) {
|
||||
var req SetupRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数有误",
|
||||
})
|
||||
@@ -77,7 +77,7 @@ func PostSetup(c *gin.Context) {
|
||||
if !rootExists {
|
||||
// Validate username length: max 12 characters to align with model.User validation
|
||||
if len(req.Username) > 12 {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名长度不能超过12个字符",
|
||||
})
|
||||
@@ -85,7 +85,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "两次输入的密码不一致",
|
||||
})
|
||||
@@ -93,7 +93,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "密码长度至少为8个字符",
|
||||
})
|
||||
@@ -103,7 +103,7 @@ func PostSetup(c *gin.Context) {
|
||||
// Create root user
|
||||
hashedPassword, err := common.Password2Hash(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "系统错误: " + err.Error(),
|
||||
})
|
||||
@@ -120,7 +120,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
err = model.DB.Create(&rootUser).Error
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "创建管理员账号失败: " + err.Error(),
|
||||
})
|
||||
@@ -135,7 +135,7 @@ func PostSetup(c *gin.Context) {
|
||||
// Save operation modes to database for persistence
|
||||
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "保存自用模式设置失败: " + err.Error(),
|
||||
})
|
||||
@@ -144,7 +144,7 @@ func PostSetup(c *gin.Context) {
|
||||
|
||||
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "保存演示站点模式设置失败: " + err.Error(),
|
||||
})
|
||||
@@ -160,7 +160,7 @@ func PostSetup(c *gin.Context) {
|
||||
}
|
||||
err = model.DB.Create(&setup).Error
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "系统初始化失败: " + err.Error(),
|
||||
})
|
||||
|
||||
@@ -225,8 +225,7 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
|
||||
@@ -19,12 +19,4 @@ const (
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool {
|
||||
if s == nil || s.OpenRouterEnterprise == nil {
|
||||
return false
|
||||
}
|
||||
return *s.OpenRouterEnterprise
|
||||
}
|
||||
|
||||
@@ -14,30 +14,7 @@ type GeminiChatRequest struct {
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolConfig *ToolConfig `json:"toolConfig,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||
CachedContent string `json:"cachedContent,omitempty"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
|
||||
RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCallingConfig struct {
|
||||
Mode FunctionCallingConfigMode `json:"mode,omitempty"`
|
||||
AllowedFunctionNames []string `json:"allowedFunctionNames,omitempty"`
|
||||
}
|
||||
type FunctionCallingConfigMode string
|
||||
|
||||
type RetrievalConfig struct {
|
||||
LatLng *LatLng `json:"latLng,omitempty"`
|
||||
LanguageCode string `json:"languageCode,omitempty"`
|
||||
}
|
||||
|
||||
type LatLng struct {
|
||||
Latitude *float64 `json:"latitude,omitempty"`
|
||||
Longitude *float64 `json:"longitude,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
@@ -251,7 +228,6 @@ type GeminiChatTool struct {
|
||||
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
|
||||
CodeExecution any `json:"codeExecution,omitempty"`
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
URLContext any `json:"urlContext,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
@@ -263,20 +239,12 @@ type GeminiChatGenerationConfig struct {
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
|
||||
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
|
||||
Logprobs *int32 `json:"logprobs,omitempty"`
|
||||
MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type MediaResolution string
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason *string `json:"finishReason"`
|
||||
|
||||
@@ -772,12 +772,11 @@ type OpenAIResponsesRequest struct {
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
|
||||
@@ -6,10 +6,6 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
const (
|
||||
ResponsesOutputTypeImageGenerationCall = "image_generation_call"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
Error any `json:"error"`
|
||||
@@ -277,42 +273,6 @@ func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||
return GetOpenAIError(o.Error)
|
||||
}
|
||||
|
||||
func (o *OpenAIResponsesResponse) HasImageGenerationCall() bool {
|
||||
if len(o.Output) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, output := range o.Output {
|
||||
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *OpenAIResponsesResponse) GetQuality() string {
|
||||
if len(o.Output) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, output := range o.Output {
|
||||
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||
return output.Quality
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (o *OpenAIResponsesResponse) GetSize() string {
|
||||
if len(o.Output) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, output := range o.Output {
|
||||
if output.Type == ResponsesOutputTypeImageGenerationCall {
|
||||
return output.Size
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type IncompleteDetails struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
@@ -323,8 +283,6 @@ type ResponsesOutput struct {
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Content []ResponsesOutputContent `json:"content"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
|
||||
326
examples/oauth/oauth-demo.html
Normal file
326
examples/oauth/oauth-demo.html
Normal file
@@ -0,0 +1,326 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>OAuth2/OIDC 授权码 + PKCE 前端演示</title>
|
||||
<style>
|
||||
:root { --bg:#0b0c10; --panel:#111317; --muted:#aab2bf; --accent:#3b82f6; --ok:#16a34a; --warn:#f59e0b; --err:#ef4444; --border:#1f2430; }
|
||||
body { margin:0; font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial; background: var(--bg); color:#e5e7eb; }
|
||||
.wrap { max-width: 980px; margin: 32px auto; padding: 0 16px; }
|
||||
h1 { font-size: 22px; margin:0 0 16px; }
|
||||
.card { background: var(--panel); border:1px solid var(--border); border-radius: 10px; padding: 16px; margin: 12px 0; }
|
||||
.row { display:flex; gap:12px; flex-wrap:wrap; }
|
||||
.col { flex: 1 1 280px; display:flex; flex-direction:column; }
|
||||
label { font-size: 12px; color: var(--muted); margin-bottom: 6px; }
|
||||
input, textarea, select { background:#0f1115; color:#e5e7eb; border:1px solid var(--border); padding:10px 12px; border-radius:8px; outline:none; }
|
||||
textarea { min-height: 100px; resize: vertical; }
|
||||
.btns { display:flex; gap:8px; flex-wrap:wrap; margin-top: 8px; }
|
||||
button { background:#1a1f2b; color:#e5e7eb; border:1px solid var(--border); padding:8px 12px; border-radius:8px; cursor:pointer; }
|
||||
button.primary { background: var(--accent); border-color: var(--accent); color:white; }
|
||||
button.ok { background: var(--ok); border-color: var(--ok); color:white; }
|
||||
button.warn { background: var(--warn); border-color: var(--warn); color:black; }
|
||||
button.ghost { background: transparent; }
|
||||
.muted { color: var(--muted); font-size: 12px; }
|
||||
.mono { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
|
||||
.grid2 { display:grid; grid-template-columns: 1fr 1fr; gap: 12px; }
|
||||
@media (max-width: 880px){ .grid2 { grid-template-columns: 1fr; } }
|
||||
.pill { padding: 3px 8px; border-radius:999px; font-size: 12px; border:1px solid var(--border); background:#0f1115; }
|
||||
.ok { color: #10b981; }
|
||||
.err { color: #ef4444; }
|
||||
.sep { height:1px; background: var(--border); margin: 12px 0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<h1>OAuth2/OIDC 授权码 + PKCE 前端演示</h1>
|
||||
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Issuer(可选,用于自动发现 /.well-known/openid-configuration)</label>
|
||||
<input id="issuer" placeholder="https://your-domain" />
|
||||
<div class="btns"><button class="" id="btnDiscover">自动发现端点</button></div>
|
||||
<div class="muted">提示:若未配置 Issuer,可直接填写下方端点。</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>Authorization Endpoint</label><input id="authorization_endpoint" placeholder="https://domain/api/oauth/authorize" /></div>
|
||||
<div class="col"><label>Token Endpoint</label><input id="token_endpoint" placeholder="https://domain/api/oauth/token" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>UserInfo Endpoint(可选)</label><input id="userinfo_endpoint" placeholder="https://domain/api/oauth/userinfo" /></div>
|
||||
<div class="col"><label>Client ID</label><input id="client_id" placeholder="your-public-client-id" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>Redirect URI(当前页地址或你的回调)</label><input id="redirect_uri" /></div>
|
||||
<div class="col"><label>Scope</label><input id="scope" value="openid profile email" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>State</label><input id="state" /></div>
|
||||
<div class="col"><label>Nonce</label><input id="nonce" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>Code Verifier(自动生成,不会上送)</label><input id="code_verifier" class="mono" readonly /></div>
|
||||
<div class="col"><label>Code Challenge(S256)</label><input id="code_challenge" class="mono" readonly /></div>
|
||||
</div>
|
||||
<div class="btns">
|
||||
<button id="btnGenPkce">生成 PKCE</button>
|
||||
<button id="btnRandomState">随机 State</button>
|
||||
<button id="btnRandomNonce">随机 Nonce</button>
|
||||
<button id="btnMakeAuthURL">生成授权链接</button>
|
||||
<button id="btnAuthorize" class="primary">跳转授权</button>
|
||||
</div>
|
||||
<div class="row" style="margin-top:8px;">
|
||||
<div class="col">
|
||||
<label>授权链接(只生成不跳转)</label>
|
||||
<textarea id="authorize_url" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns"><button id="btnCopyAuthURL">复制链接</button></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="sep"></div>
|
||||
<div class="muted">说明:
|
||||
<ul>
|
||||
<li>本页为纯前端演示,适用于公开客户端(不需要 client_secret)。</li>
|
||||
<li>如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo 部署到同源域名下。</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="sep"></div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>粘贴 OIDC Discovery JSON(/.well-known/openid-configuration)</label>
|
||||
<textarea id="conf_json" class="mono" placeholder='{"issuer":"https://...","authorization_endpoint":"...","token_endpoint":"...","userinfo_endpoint":"..."}'></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnParseConf">解析并填充端点</button>
|
||||
<button id="btnGenConf">用当前端点生成 JSON</button>
|
||||
</div>
|
||||
<div class="muted">可将服务端返回的 OIDC Discovery JSON 粘贴到此处,点击“解析并填充端点”。</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>授权结果</label>
|
||||
<div id="authResult" class="muted">等待授权...</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top:12px;">
|
||||
<div>
|
||||
<label>Access Token</label>
|
||||
<textarea id="access_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnCopyAT">复制</button>
|
||||
<button id="btnCallUserInfo" class="ok">调用 UserInfo</button>
|
||||
</div>
|
||||
<div id="userinfoOut" class="muted" style="margin-top:6px;"></div>
|
||||
</div>
|
||||
<div>
|
||||
<label>ID Token(JWT)</label>
|
||||
<textarea id="id_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnDecodeJWT">解码显示 Claims</button>
|
||||
</div>
|
||||
<pre id="jwtClaims" class="mono" style="white-space:pre-wrap; word-break:break-all; margin-top:6px;"></pre>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top:12px;">
|
||||
<div>
|
||||
<label>Refresh Token</label>
|
||||
<textarea id="refresh_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnRefreshToken">使用 Refresh Token 刷新</button>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label>原始 Token 响应</label>
|
||||
<textarea id="token_raw" class="mono" placeholder="(空)"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const $ = (id) => document.getElementById(id);
|
||||
const toB64Url = (buf) => btoa(String.fromCharCode(...new Uint8Array(buf))).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, '');
|
||||
async function sha256B64Url(str){
|
||||
const data = new TextEncoder().encode(str);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return toB64Url(digest);
|
||||
}
|
||||
function randStr(len=64){
|
||||
const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
|
||||
const arr = new Uint8Array(len); crypto.getRandomValues(arr);
|
||||
return Array.from(arr, v => chars[v % chars.length]).join('');
|
||||
}
|
||||
function setAuthInfo(msg, ok=true){
|
||||
const el = $('authResult');
|
||||
el.textContent = msg;
|
||||
el.className = ok ? 'ok' : 'err';
|
||||
}
|
||||
function qs(name){ const u=new URL(location.href); return u.searchParams.get(name); }
|
||||
|
||||
function persist(name, val){ sessionStorage.setItem('demo_'+name, val); }
|
||||
function load(name){ return sessionStorage.getItem('demo_'+name) || ''; }
|
||||
|
||||
// init defaults
|
||||
(function init(){
|
||||
$('redirect_uri').value = window.location.origin + window.location.pathname;
|
||||
// try load from discovery if issuer saved previously
|
||||
const iss = load('issuer'); if(iss) $('issuer').value = iss;
|
||||
const cid = load('client_id'); if(cid) $('client_id').value = cid;
|
||||
const scp = load('scope'); if(scp) $('scope').value = scp;
|
||||
})();
|
||||
|
||||
$('btnDiscover').onclick = async () => {
|
||||
const iss = $('issuer').value.trim(); if(!iss){ alert('请填写 Issuer'); return; }
|
||||
try{
|
||||
persist('issuer', iss);
|
||||
const res = await fetch(iss.replace(/\/$/,'') + '/api/.well-known/openid-configuration');
|
||||
const d = await res.json();
|
||||
$('authorization_endpoint').value = d.authorization_endpoint || '';
|
||||
$('token_endpoint').value = d.token_endpoint || '';
|
||||
$('userinfo_endpoint').value = d.userinfo_endpoint || '';
|
||||
if (d.issuer) { $('issuer').value = d.issuer; persist('issuer', d.issuer); }
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
setAuthInfo('已从发现文档加载端点', true);
|
||||
}catch(e){ setAuthInfo('自动发现失败:'+e, false); }
|
||||
};
|
||||
|
||||
$('btnGenPkce').onclick = async () => {
|
||||
const v = randStr(64); const c = await sha256B64Url(v);
|
||||
$('code_verifier').value = v; $('code_challenge').value = c;
|
||||
persist('code_verifier', v); persist('code_challenge', c);
|
||||
setAuthInfo('已生成 PKCE 参数', true);
|
||||
};
|
||||
$('btnRandomState').onclick = () => { $('state').value = randStr(16); persist('state', $('state').value); };
|
||||
$('btnRandomNonce').onclick = () => { $('nonce').value = randStr(16); persist('nonce', $('nonce').value); };
|
||||
|
||||
function buildAuthorizeURLFromFields() {
|
||||
const auth = $('authorization_endpoint').value.trim();
|
||||
const token = $('token_endpoint').value.trim(); // just validate
|
||||
const cid = $('client_id').value.trim();
|
||||
const red = $('redirect_uri').value.trim();
|
||||
const scp = $('scope').value.trim() || 'openid profile email';
|
||||
const st = $('state').value.trim() || randStr(16);
|
||||
const no = $('nonce').value.trim() || randStr(16);
|
||||
const cc = $('code_challenge').value.trim();
|
||||
const cv = $('code_verifier').value.trim();
|
||||
if(!auth || !token || !cid || !red){ throw new Error('请先完善端点/ClientID/RedirectURI'); }
|
||||
if(!cc || !cv){ throw new Error('请先生成 PKCE'); }
|
||||
persist('authorization_endpoint', auth); persist('token_endpoint', token);
|
||||
persist('client_id', cid); persist('redirect_uri', red); persist('scope', scp);
|
||||
persist('state', st); persist('nonce', no); persist('code_verifier', cv);
|
||||
const u = new URL(auth);
|
||||
u.searchParams.set('response_type', 'code');
|
||||
u.searchParams.set('client_id', cid);
|
||||
u.searchParams.set('redirect_uri', red);
|
||||
u.searchParams.set('scope', scp);
|
||||
u.searchParams.set('state', st);
|
||||
u.searchParams.set('nonce', no);
|
||||
u.searchParams.set('code_challenge', cc);
|
||||
u.searchParams.set('code_challenge_method', 'S256');
|
||||
return u.toString();
|
||||
}
|
||||
$('btnMakeAuthURL').onclick = () => {
|
||||
try {
|
||||
const url = buildAuthorizeURLFromFields();
|
||||
$('authorize_url').value = url;
|
||||
setAuthInfo('已生成授权链接', true);
|
||||
} catch(e){ setAuthInfo(e.message, false); }
|
||||
};
|
||||
$('btnAuthorize').onclick = () => {
|
||||
try { const url = buildAuthorizeURLFromFields(); location.href = url; }
|
||||
catch(e){ setAuthInfo(e.message, false); }
|
||||
};
|
||||
$('btnCopyAuthURL').onclick = async () => { try{ await navigator.clipboard.writeText($('authorize_url').value); }catch{} };
|
||||
|
||||
// Parse OIDC discovery JSON pasted by user
|
||||
$('btnParseConf').onclick = () => {
|
||||
const txt = $('conf_json').value.trim(); if(!txt){ alert('请先粘贴 JSON'); return; }
|
||||
try{
|
||||
const d = JSON.parse(txt);
|
||||
if (d.issuer) { $('issuer').value = d.issuer; persist('issuer', d.issuer); }
|
||||
if (d.authorization_endpoint) $('authorization_endpoint').value = d.authorization_endpoint;
|
||||
if (d.token_endpoint) $('token_endpoint').value = d.token_endpoint;
|
||||
if (d.userinfo_endpoint) $('userinfo_endpoint').value = d.userinfo_endpoint;
|
||||
setAuthInfo('已解析配置并填充端点', true);
|
||||
}catch(e){ setAuthInfo('解析失败:'+e, false); }
|
||||
};
|
||||
// Generate a minimal discovery JSON from current fields
|
||||
$('btnGenConf').onclick = () => {
|
||||
const d = {
|
||||
issuer: $('issuer').value.trim() || undefined,
|
||||
authorization_endpoint: $('authorization_endpoint').value.trim() || undefined,
|
||||
token_endpoint: $('token_endpoint').value.trim() || undefined,
|
||||
userinfo_endpoint: $('userinfo_endpoint').value.trim() || undefined,
|
||||
};
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
};
|
||||
|
||||
async function postForm(url, data){
|
||||
const body = Object.entries(data).map(([k,v])=> `${encodeURIComponent(k)}=${encodeURIComponent(v)}`).join('&');
|
||||
const res = await fetch(url, { method:'POST', headers:{ 'Content-Type':'application/x-www-form-urlencoded' }, body });
|
||||
if(!res.ok){ const t = await res.text(); throw new Error(`HTTP ${res.status} ${t}`); }
|
||||
return res.json();
|
||||
}
|
||||
|
||||
async function handleCallback(){
|
||||
const code = qs('code'); const err = qs('error');
|
||||
const state = qs('state');
|
||||
if(err){ setAuthInfo('授权失败:'+err, false); return; }
|
||||
if(!code){ setAuthInfo('等待授权...', true); return; }
|
||||
// state check
|
||||
if(state && load('state') && state !== load('state')){ setAuthInfo('state 不匹配,已拒绝', false); return; }
|
||||
try{
|
||||
const tokenEp = load('token_endpoint');
|
||||
const data = await postForm(tokenEp, {
|
||||
grant_type:'authorization_code',
|
||||
code,
|
||||
client_id: load('client_id'),
|
||||
redirect_uri: load('redirect_uri'),
|
||||
code_verifier: load('code_verifier')
|
||||
});
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('授权成功,已获取令牌', true);
|
||||
}catch(e){ setAuthInfo('交换令牌失败:'+e.message, false); }
|
||||
}
|
||||
handleCallback();
|
||||
|
||||
$('btnCopyAT').onclick = async () => { try{ await navigator.clipboard.writeText($('access_token').value); }catch{} };
|
||||
$('btnDecodeJWT').onclick = () => {
|
||||
const t = $('id_token').value.trim(); if(!t){ $('jwtClaims').textContent='(空)'; return; }
|
||||
const parts = t.split('.'); if(parts.length<2){ $('jwtClaims').textContent='格式错误'; return; }
|
||||
try{ const json = JSON.parse(atob(parts[1].replace(/-/g,'+').replace(/_/g,'/'))); $('jwtClaims').textContent = JSON.stringify(json, null, 2);}catch(e){ $('jwtClaims').textContent='解码失败:'+e; }
|
||||
};
|
||||
$('btnCallUserInfo').onclick = async () => {
|
||||
const at = $('access_token').value.trim(); const ep = $('userinfo_endpoint').value.trim(); if(!at||!ep){ alert('请填写UserInfo端点并获取AccessToken'); return; }
|
||||
try{
|
||||
const res = await fetch(ep, { headers:{ Authorization: 'Bearer '+at } });
|
||||
const data = await res.json(); $('userinfoOut').textContent = JSON.stringify(data, null, 2);
|
||||
}catch(e){ $('userinfoOut').textContent = '调用失败:'+e; }
|
||||
};
|
||||
$('btnRefreshToken').onclick = async () => {
|
||||
const rt = $('refresh_token').value.trim(); if(!rt){ alert('没有刷新令牌'); return; }
|
||||
try{
|
||||
const tokenEp = load('token_endpoint');
|
||||
const data = await postForm(tokenEp, {
|
||||
grant_type:'refresh_token',
|
||||
refresh_token: rt,
|
||||
client_id: load('client_id')
|
||||
});
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('刷新成功', true);
|
||||
}catch(e){ setAuthInfo('刷新失败:'+e.message, false); }
|
||||
};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
181
examples/oauth/oauth2_test_client.go
Normal file
181
examples/oauth/oauth2_test_client.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 测试 Client Credentials 流程
|
||||
//testClientCredentials()
|
||||
|
||||
// 测试 Authorization Code + PKCE 流程(需要浏览器交互)
|
||||
testAuthorizationCode()
|
||||
}
|
||||
|
||||
// testClientCredentials 测试服务对服务认证
|
||||
func testClientCredentials() {
|
||||
fmt.Println("=== Testing Client Credentials Flow ===")
|
||||
|
||||
cfg := clientcredentials.Config{
|
||||
ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
|
||||
ClientSecret: "hLLdn2Ia4UM7hcsJaSuUFDV0Px9BrkNq",
|
||||
TokenURL: "http://localhost:3000/api/oauth/token",
|
||||
Scopes: []string{"api:read", "api:write"},
|
||||
EndpointParams: map[string][]string{
|
||||
"audience": {"api://new-api"},
|
||||
},
|
||||
}
|
||||
|
||||
// 创建HTTP客户端
|
||||
httpClient := cfg.Client(context.Background())
|
||||
|
||||
// 调用受保护的API
|
||||
resp, err := httpClient.Get("http://localhost:3000/api/status")
|
||||
if err != nil {
|
||||
log.Printf("Request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Status: %s\n", resp.Status)
|
||||
fmt.Printf("Response: %s\n", string(body))
|
||||
}
|
||||
|
||||
// testAuthorizationCode 测试授权码流程
|
||||
func testAuthorizationCode() {
|
||||
fmt.Println("=== Testing Authorization Code + PKCE Flow ===")
|
||||
|
||||
conf := oauth2.Config{
|
||||
ClientID: "client_dsFyyoyNZWjhbNa2", // 需要先创建客户端
|
||||
ClientSecret: "JHiugKf89OMmTLuZMZyA2sgZnO0Ioae3",
|
||||
RedirectURL: "http://localhost:9999/callback",
|
||||
// 包含 openid/profile/email 以便调用 UserInfo
|
||||
Scopes: []string{"openid", "profile", "email", "api:read"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "http://localhost:3000/api/oauth/authorize",
|
||||
TokenURL: "http://localhost:3000/api/oauth/token",
|
||||
},
|
||||
}
|
||||
|
||||
// 生成PKCE参数
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
state := fmt.Sprintf("state-%d", time.Now().Unix())
|
||||
|
||||
// 构建授权URL
|
||||
url := conf.AuthCodeURL(
|
||||
state,
|
||||
oauth2.S256ChallengeOption(codeVerifier),
|
||||
//oauth2.SetAuthURLParam("audience", "api://new-api"),
|
||||
)
|
||||
|
||||
fmt.Printf("Visit this URL to authorize:\n%s\n\n", url)
|
||||
fmt.Printf("A local server will listen on http://localhost:9999/callback to receive the code...\n")
|
||||
|
||||
// 启动回调本地服务器,自动接收授权码
|
||||
codeCh := make(chan string, 1)
|
||||
srv := &http.Server{Addr: ":9999"}
|
||||
http.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
if errParam := q.Get("error"); errParam != "" {
|
||||
fmt.Fprintf(w, "Authorization failed: %s", errParam)
|
||||
return
|
||||
}
|
||||
gotState := q.Get("state")
|
||||
if gotState != state {
|
||||
http.Error(w, "state mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := q.Get("code")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w, "Authorization received. You may close this window.")
|
||||
select {
|
||||
case codeCh <- code:
|
||||
default:
|
||||
}
|
||||
go func() {
|
||||
// 稍后关闭服务
|
||||
_ = srv.Shutdown(context.Background())
|
||||
}()
|
||||
})
|
||||
go func() {
|
||||
_ = srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
// 等待授权码
|
||||
var code string
|
||||
select {
|
||||
case code = <-codeCh:
|
||||
case <-time.After(5 * time.Minute):
|
||||
log.Println("Timeout waiting for authorization code")
|
||||
_ = srv.Shutdown(context.Background())
|
||||
return
|
||||
}
|
||||
|
||||
// 交换令牌
|
||||
token, err := conf.Exchange(
|
||||
context.Background(),
|
||||
code,
|
||||
oauth2.VerifierOption(codeVerifier),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Token exchange failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Access Token: %s\n", token.AccessToken)
|
||||
fmt.Printf("Token Type: %s\n", token.TokenType)
|
||||
fmt.Printf("Expires In: %v\n", token.Expiry)
|
||||
|
||||
// 使用令牌调用 UserInfo
|
||||
client := conf.Client(context.Background(), token)
|
||||
userInfoURL := buildUserInfoFromAuth(conf.Endpoint.AuthURL)
|
||||
resp, err := client.Get(userInfoURL)
|
||||
if err != nil {
|
||||
log.Printf("UserInfo request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read UserInfo response: %v", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("UserInfo: %s\n", string(body))
|
||||
}
|
||||
|
||||
// buildUserInfoFromAuth 将授权端点URL转换为UserInfo端点URL
|
||||
func buildUserInfoFromAuth(auth string) string {
|
||||
u, err := url.Parse(auth)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
// 将最后一个路径段 authorize 替换为 userinfo
|
||||
dir := path.Dir(u.Path)
|
||||
if strings.HasSuffix(u.Path, "/authorize") {
|
||||
u.Path = path.Join(dir, "userinfo")
|
||||
} else {
|
||||
// 回退:追加默认 /oauth/userinfo
|
||||
u.Path = path.Join(dir, "userinfo")
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
23
go.mod
23
go.mod
@@ -11,20 +11,24 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
|
||||
github.com/aws/smithy-go v1.22.5
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
github.com/gin-contrib/static v0.0.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/glebarez/sqlite v1.9.0
|
||||
github.com/go-oauth2/gin-server v1.1.0
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.4
|
||||
github.com/go-playground/validator/v10 v10.20.0
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/samber/lo v1.39.0
|
||||
@@ -38,6 +42,7 @@ require (
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.11.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
@@ -55,6 +60,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
@@ -65,7 +71,7 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/goccy/go-json v0.10.3 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
@@ -79,14 +85,25 @@ require (
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
|
||||
github.com/lestrrat-go/httpcc v1.0.1 // indirect
|
||||
github.com/lestrrat-go/httprc v1.0.6 // indirect
|
||||
github.com/lestrrat-go/iter v1.0.2 // indirect
|
||||
github.com/lestrrat-go/option v1.0.1 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 // indirect
|
||||
github.com/tidwall/buntdb v1.1.2 // indirect
|
||||
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e // indirect
|
||||
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
@@ -94,7 +111,7 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
94
go.sum
94
go.sum
@@ -1,5 +1,7 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
@@ -23,8 +25,8 @@ github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
|
||||
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
@@ -39,16 +41,22 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
|
||||
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8=
|
||||
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
|
||||
github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
|
||||
github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
|
||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
|
||||
@@ -67,6 +75,10 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
|
||||
github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
|
||||
github.com/go-oauth2/gin-server v1.1.0 h1:+7AyIfrcKaThZxxABRYECysxAfTccgpFdAqY1enuzBk=
|
||||
github.com/go-oauth2/gin-server v1.1.0/go.mod h1:f08F3l5/Pbayb4pjnv5PpUdQLFejgGfHrTjA6IZb0eM=
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.4 h1:YjI0tmGW8oxVhn9QSBIxlr641QugWrJY5UWa6XmLcW0=
|
||||
github.com/go-oauth2/oauth2/v4 v4.5.4/go.mod h1:BXiOY+QZtZy2ewbsGk2B5P8TWmtz/Rf7ES5ZttQFxfQ=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
@@ -90,20 +102,26 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=
|
||||
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
@@ -112,6 +130,8 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
|
||||
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
@@ -132,6 +152,10 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||
github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U=
|
||||
github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
|
||||
@@ -148,6 +172,18 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx
|
||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
|
||||
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
|
||||
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
|
||||
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
|
||||
github.com/lestrrat-go/httprc v1.0.6 h1:qgmgIRhpvBqexMJjA/PmwSvhNk679oqD1RbovdCGW8k=
|
||||
github.com/lestrrat-go/httprc v1.0.6/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
|
||||
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
|
||||
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6 h1:hxM1gfDILk/l5ylers6BX/Eq1m/pnxe9NBwW6lVfecA=
|
||||
github.com/lestrrat-go/jwx/v2 v2.1.6/go.mod h1:Y722kU5r/8mV7fYDifjug0r8FK8mZdw0K0GpJw/l8pU=
|
||||
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
|
||||
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
@@ -160,6 +196,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
|
||||
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
@@ -184,10 +222,18 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
|
||||
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
|
||||
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
||||
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
@@ -200,21 +246,35 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
|
||||
github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
|
||||
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
|
||||
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 h1:G6Z6HvJuPjG6XfNGi/feOATzeJrfgTNJY+rGrHbA04E=
|
||||
github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
|
||||
github.com/tidwall/buntdb v1.1.2 h1:noCrqQXL9EKMtcdwJcmuVKSEjqu1ua99RHHgbLTEHRo=
|
||||
github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI=
|
||||
github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb h1:5NSYaAdrnblKByzd7XByQEJVT8+9v0W/tIY0Oo4OwrE=
|
||||
github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M=
|
||||
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e h1:+NL1GDIUOKxVfbp2KoJQD9cTQ6dyP2co9q4yzmT9FZo=
|
||||
github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563 h1:Otn9S136ELckZ3KKDyCkxapfufrqDqwmGjcHfAyXRrE=
|
||||
github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
@@ -229,8 +289,24 @@ github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLY
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
|
||||
github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c=
|
||||
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=
|
||||
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=
|
||||
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY=
|
||||
github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI=
|
||||
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
|
||||
github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM=
|
||||
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
|
||||
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
@@ -247,6 +323,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
@@ -257,12 +335,12 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
|
||||
35
main.go
35
main.go
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -15,10 +14,9 @@ import (
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/src/oauth"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-contrib/sessions"
|
||||
@@ -36,7 +34,6 @@ var buildFS embed.FS
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
startTime := time.Now()
|
||||
|
||||
err := InitResources()
|
||||
if err != nil {
|
||||
@@ -149,31 +146,11 @@ func main() {
|
||||
})
|
||||
server.Use(sessions.Sessions("session", store))
|
||||
|
||||
analyticsInjectBuilder := &strings.Builder{}
|
||||
if os.Getenv("UMAMI_WEBSITE_ID") != "" {
|
||||
umamiSiteID := os.Getenv("UMAMI_WEBSITE_ID")
|
||||
umamiScriptURL := os.Getenv("UMAMI_SCRIPT_URL")
|
||||
if umamiScriptURL == "" {
|
||||
umamiScriptURL = "https://analytics.umami.is/script.js"
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<script defer src=\"")
|
||||
analyticsInjectBuilder.WriteString(umamiScriptURL)
|
||||
analyticsInjectBuilder.WriteString("\" data-website-id=\"")
|
||||
analyticsInjectBuilder.WriteString(umamiSiteID)
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<analytics></analytics>\n"), []byte(analyticsInject))
|
||||
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
}
|
||||
|
||||
// Log startup success message
|
||||
common.LogStartupSuccess(startTime, port)
|
||||
|
||||
err = server.Run(":" + port)
|
||||
if err != nil {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
@@ -227,5 +204,13 @@ func InitResources() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize OAuth2 server
|
||||
err = oauth.InitOAuthServer()
|
||||
if err != nil {
|
||||
common.SysLog("Warning: Failed to initialize OAuth2 server: " + err.Error())
|
||||
// OAuth2 失败不应该阻止系统启动
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -8,11 +8,14 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/src/oauth"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
func validUserInfo(username string, role int) bool {
|
||||
@@ -177,6 +180,7 @@ func WssAuth(c *gin.Context) {
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
rawAuth := c.Request.Header.Get("Authorization")
|
||||
// 先检测是否为ws
|
||||
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
|
||||
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
|
||||
@@ -235,6 +239,11 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// OAuth Bearer fallback
|
||||
if tryOAuthBearer(c, rawAuth) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -288,6 +297,74 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// tryOAuthBearer validates an OAuth JWT access token and sets minimal context for relay
|
||||
func tryOAuthBearer(c *gin.Context, rawAuth string) bool {
|
||||
if rawAuth == "" || !strings.HasPrefix(rawAuth, "Bearer ") {
|
||||
return false
|
||||
}
|
||||
tokenString := strings.TrimSpace(strings.TrimPrefix(rawAuth, "Bearer "))
|
||||
if tokenString == "" {
|
||||
return false
|
||||
}
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
// Parse & verify
|
||||
parsed, err := jwt.Parse(tokenString, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
if kid, ok := t.Header["kid"].(string); ok {
|
||||
if settings.JWTKeyID != "" && kid != settings.JWTKeyID {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
}
|
||||
pub := oauth.GetRSAPublicKey()
|
||||
if pub == nil {
|
||||
return nil, jwt.ErrTokenUnverifiable
|
||||
}
|
||||
return pub, nil
|
||||
})
|
||||
if err != nil || parsed == nil || !parsed.Valid {
|
||||
return false
|
||||
}
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// issuer check when configured
|
||||
if iss, ok2 := claims["iss"].(string); !ok2 || (settings.Issuer != "" && iss != settings.Issuer) {
|
||||
return false
|
||||
}
|
||||
// revoke check
|
||||
if jti, ok2 := claims["jti"].(string); ok2 && jti != "" {
|
||||
if revoked, _ := model.IsTokenRevoked(jti); revoked {
|
||||
return false
|
||||
}
|
||||
}
|
||||
// scope check: must contain api:read or api:write or admin
|
||||
scope, _ := claims["scope"].(string)
|
||||
scopePadded := " " + scope + " "
|
||||
if !(strings.Contains(scopePadded, " api:read ") || strings.Contains(scopePadded, " api:write ") || strings.Contains(scopePadded, " admin ")) {
|
||||
return false
|
||||
}
|
||||
// subject must be user id to support quota logic
|
||||
sub, _ := claims["sub"].(string)
|
||||
uid, err := strconv.Atoi(sub)
|
||||
if err != nil || uid <= 0 {
|
||||
return false
|
||||
}
|
||||
// load user cache & set context
|
||||
userCache, err := model.GetUserCache(uid)
|
||||
if err != nil || userCache == nil || userCache.Status != common.UserStatusEnabled {
|
||||
return false
|
||||
}
|
||||
c.Set("id", uid)
|
||||
c.Set("group", userCache.Group)
|
||||
c.Set("user_group", userCache.Group)
|
||||
// set UsingGroup
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userCache.Group)
|
||||
return true
|
||||
}
|
||||
|
||||
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
|
||||
291
middleware/oauth_jwt.go
Normal file
291
middleware/oauth_jwt.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/src/oauth"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// OAuthJWTAuth OAuth2 JWT认证中间件
|
||||
func OAuthJWTAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查OAuth2是否启用
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
if !settings.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.Next() // 没有Authorization header,继续到下一个中间件
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为Bearer token
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
c.Next() // 不是Bearer token,继续到下一个中间件
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == "" {
|
||||
abortWithOAuthError(c, "invalid_token", "Missing token")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证JWT token
|
||||
claims, err := validateOAuthJWT(tokenString)
|
||||
if err != nil {
|
||||
abortWithOAuthError(c, "invalid_token", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token的有效性
|
||||
if err := validateOAuthClaims(claims); err != nil {
|
||||
abortWithOAuthError(c, "invalid_token", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 设置上下文信息
|
||||
setOAuthContext(c, claims)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// validateOAuthJWT 验证OAuth2 JWT令牌
|
||||
func validateOAuthJWT(tokenString string) (jwt.MapClaims, error) {
|
||||
// 解析JWT而不验证签名(先获取header中的kid)
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// 检查签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
// 获取kid
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing kid in token header")
|
||||
}
|
||||
|
||||
// 根据kid获取公钥
|
||||
publicKey, err := getPublicKeyByKid(kid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// getPublicKeyByKid 根据kid获取公钥
|
||||
func getPublicKeyByKid(kid string) (*rsa.PublicKey, error) {
|
||||
// 这里需要从JWKS获取公钥
|
||||
// 在实际实现中,你可能需要从OAuth server获取JWKS
|
||||
// 这里先实现一个简单版本
|
||||
|
||||
// TODO: 实现JWKS缓存和刷新机制
|
||||
pub := oauth.GetPublicKeyByKid(kid)
|
||||
if pub == nil {
|
||||
return nil, fmt.Errorf("unknown kid: %s", kid)
|
||||
}
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
// validateOAuthClaims 验证OAuth2 claims
|
||||
func validateOAuthClaims(claims jwt.MapClaims) error {
|
||||
settings := system_setting.GetOAuth2Settings()
|
||||
|
||||
// 验证issuer(若配置了 Issuer 则强校验,否则仅要求存在)
|
||||
if iss, ok := claims["iss"].(string); ok {
|
||||
if settings.Issuer != "" && iss != settings.Issuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("missing issuer claim")
|
||||
}
|
||||
|
||||
// 验证audience
|
||||
// if aud, ok := claims["aud"].(string); ok {
|
||||
// // TODO: 验证audience
|
||||
// }
|
||||
|
||||
// 验证客户端ID
|
||||
if clientID, ok := claims["client_id"].(string); ok {
|
||||
// 验证客户端是否存在且有效
|
||||
client, err := model.GetOAuthClientByID(clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid client")
|
||||
}
|
||||
if client.Status != common.UserStatusEnabled {
|
||||
return fmt.Errorf("client disabled")
|
||||
}
|
||||
|
||||
// 检查是否被撤销
|
||||
if jti, ok := claims["jti"].(string); ok && jti != "" {
|
||||
revoked, _ := model.IsTokenRevoked(jti)
|
||||
if revoked {
|
||||
return fmt.Errorf("token revoked")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("missing client_id claim")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setOAuthContext 设置OAuth上下文信息
|
||||
func setOAuthContext(c *gin.Context, claims jwt.MapClaims) {
|
||||
c.Set("oauth_claims", claims)
|
||||
c.Set("oauth_authenticated", true)
|
||||
|
||||
// 提取基本信息
|
||||
if clientID, ok := claims["client_id"].(string); ok {
|
||||
c.Set("oauth_client_id", clientID)
|
||||
}
|
||||
|
||||
if scope, ok := claims["scope"].(string); ok {
|
||||
c.Set("oauth_scope", scope)
|
||||
}
|
||||
|
||||
if sub, ok := claims["sub"].(string); ok {
|
||||
c.Set("oauth_subject", sub)
|
||||
}
|
||||
|
||||
// 对于client_credentials流程,subject就是client_id
|
||||
// 对于authorization_code流程,subject是用户ID
|
||||
if grantType, ok := claims["grant_type"].(string); ok {
|
||||
c.Set("oauth_grant_type", grantType)
|
||||
}
|
||||
}
|
||||
|
||||
// abortWithOAuthError 返回OAuth错误响应
|
||||
func abortWithOAuthError(c *gin.Context, errorCode, description string) {
|
||||
c.Header("WWW-Authenticate", fmt.Sprintf(`Bearer error="%s", error_description="%s"`, errorCode, description))
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": errorCode,
|
||||
"error_description": description,
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// RequireOAuthScope OAuth2 scope验证中间件
|
||||
func RequireOAuthScope(requiredScope string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否通过OAuth认证
|
||||
if !c.GetBool("oauth_authenticated") {
|
||||
abortWithOAuthError(c, "insufficient_scope", "OAuth2 authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取token的scope
|
||||
scope, exists := c.Get("oauth_scope")
|
||||
if !exists {
|
||||
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
|
||||
return
|
||||
}
|
||||
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否包含所需的scope
|
||||
scopes := strings.Split(scopeStr, " ")
|
||||
for _, s := range scopes {
|
||||
if strings.TrimSpace(s) == requiredScope {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalOAuthAuth 可选的OAuth认证中间件(不会阻止请求)
|
||||
func OptionalOAuthAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试OAuth认证,但不会阻止请求
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if claims, err := validateOAuthJWT(tokenString); err == nil {
|
||||
if validateOAuthClaims(claims) == nil {
|
||||
setOAuthContext(c, claims)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireOAuthScopeIfPresent enforces scope only when OAuth is present; otherwise no-op
|
||||
func RequireOAuthScopeIfPresent(requiredScope string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !c.GetBool("oauth_authenticated") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
scope, exists := c.Get("oauth_scope")
|
||||
if !exists {
|
||||
abortWithOAuthError(c, "insufficient_scope", "No scope in token")
|
||||
return
|
||||
}
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
abortWithOAuthError(c, "insufficient_scope", "Invalid scope format")
|
||||
return
|
||||
}
|
||||
scopes := strings.Split(scopeStr, " ")
|
||||
for _, s := range scopes {
|
||||
if strings.TrimSpace(s) == requiredScope {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
abortWithOAuthError(c, "insufficient_scope", fmt.Sprintf("Required scope: %s", requiredScope))
|
||||
}
|
||||
}
|
||||
|
||||
// GetOAuthClaims 获取OAuth claims
|
||||
func GetOAuthClaims(c *gin.Context) (jwt.MapClaims, bool) {
|
||||
claims, exists := c.Get("oauth_claims")
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
mapClaims, ok := claims.(jwt.MapClaims)
|
||||
return mapClaims, ok
|
||||
}
|
||||
|
||||
// IsOAuthAuthenticated 检查是否通过OAuth认证
|
||||
func IsOAuthAuthenticated(c *gin.Context) bool {
|
||||
return c.GetBool("oauth_authenticated")
|
||||
}
|
||||
@@ -265,6 +265,7 @@ func migrateDB() error {
|
||||
&Setup{},
|
||||
&TwoFA{},
|
||||
&TwoFABackupCode{},
|
||||
&OAuthClient{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
183
model/oauth_client.go
Normal file
183
model/oauth_client.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OAuthClient OAuth2 客户端模型
|
||||
type OAuthClient struct {
|
||||
ID string `json:"id" gorm:"type:varchar(64);primaryKey"`
|
||||
Secret string `json:"secret" gorm:"type:varchar(128);not null"`
|
||||
Name string `json:"name" gorm:"type:varchar(255);not null"`
|
||||
Domain string `json:"domain" gorm:"type:varchar(255)"` // 允许的重定向域名
|
||||
RedirectURIs string `json:"redirect_uris" gorm:"type:text"` // JSON array of redirect URIs
|
||||
GrantTypes string `json:"grant_types" gorm:"type:varchar(255);default:'client_credentials'"`
|
||||
Scopes string `json:"scopes" gorm:"type:varchar(255);default:'api:read'"`
|
||||
RequirePKCE bool `json:"require_pkce" gorm:"default:true"`
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // 1: enabled, 2: disabled
|
||||
CreatedBy int `json:"created_by" gorm:"type:int;not null"` // 创建者用户ID
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
LastUsedTime int64 `json:"last_used_time" gorm:"bigint;default:0"`
|
||||
TokenCount int `json:"token_count" gorm:"type:int;default:0"` // 已签发的token数量
|
||||
Description string `json:"description" gorm:"type:text"`
|
||||
ClientType string `json:"client_type" gorm:"type:varchar(32);default:'confidential'"` // confidential, public
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
// GetRedirectURIs 获取重定向URI列表
|
||||
func (c *OAuthClient) GetRedirectURIs() []string {
|
||||
if c.RedirectURIs == "" {
|
||||
return []string{}
|
||||
}
|
||||
var uris []string
|
||||
err := json.Unmarshal([]byte(c.RedirectURIs), &uris)
|
||||
if err != nil {
|
||||
common.SysLog("failed to unmarshal redirect URIs: " + err.Error())
|
||||
return []string{}
|
||||
}
|
||||
return uris
|
||||
}
|
||||
|
||||
// SetRedirectURIs 设置重定向URI列表
|
||||
func (c *OAuthClient) SetRedirectURIs(uris []string) {
|
||||
data, err := json.Marshal(uris)
|
||||
if err != nil {
|
||||
common.SysLog("failed to marshal redirect URIs: " + err.Error())
|
||||
return
|
||||
}
|
||||
c.RedirectURIs = string(data)
|
||||
}
|
||||
|
||||
// GetGrantTypes 获取允许的授权类型列表
|
||||
func (c *OAuthClient) GetGrantTypes() []string {
|
||||
if c.GrantTypes == "" {
|
||||
return []string{"client_credentials"}
|
||||
}
|
||||
return strings.Split(c.GrantTypes, ",")
|
||||
}
|
||||
|
||||
// SetGrantTypes 设置允许的授权类型列表
|
||||
func (c *OAuthClient) SetGrantTypes(types []string) {
|
||||
c.GrantTypes = strings.Join(types, ",")
|
||||
}
|
||||
|
||||
// GetScopes 获取允许的scope列表
|
||||
func (c *OAuthClient) GetScopes() []string {
|
||||
if c.Scopes == "" {
|
||||
return []string{"api:read"}
|
||||
}
|
||||
return strings.Split(c.Scopes, ",")
|
||||
}
|
||||
|
||||
// SetScopes 设置允许的scope列表
|
||||
func (c *OAuthClient) SetScopes(scopes []string) {
|
||||
c.Scopes = strings.Join(scopes, ",")
|
||||
}
|
||||
|
||||
// ValidateRedirectURI 验证重定向URI是否有效
|
||||
func (c *OAuthClient) ValidateRedirectURI(uri string) bool {
|
||||
allowedURIs := c.GetRedirectURIs()
|
||||
for _, allowedURI := range allowedURIs {
|
||||
if allowedURI == uri {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateGrantType 验证授权类型是否被允许
|
||||
func (c *OAuthClient) ValidateGrantType(grantType string) bool {
|
||||
allowedTypes := c.GetGrantTypes()
|
||||
for _, allowedType := range allowedTypes {
|
||||
if allowedType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateScope 验证scope是否被允许
|
||||
func (c *OAuthClient) ValidateScope(scope string) bool {
|
||||
allowedScopes := c.GetScopes()
|
||||
requestedScopes := strings.Split(scope, " ")
|
||||
|
||||
for _, requestedScope := range requestedScopes {
|
||||
requestedScope = strings.TrimSpace(requestedScope)
|
||||
if requestedScope == "" {
|
||||
continue
|
||||
}
|
||||
found := false
|
||||
for _, allowedScope := range allowedScopes {
|
||||
if allowedScope == requestedScope {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// BeforeCreate GORM hook - 在创建前设置时间
|
||||
func (c *OAuthClient) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
c.CreatedTime = time.Now().Unix()
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateLastUsedTime 更新最后使用时间
|
||||
func (c *OAuthClient) UpdateLastUsedTime() error {
|
||||
c.LastUsedTime = time.Now().Unix()
|
||||
c.TokenCount++
|
||||
return DB.Model(c).Select("last_used_time", "token_count").Updates(c).Error
|
||||
}
|
||||
|
||||
// GetOAuthClientByID 根据ID获取OAuth客户端
|
||||
func GetOAuthClientByID(id string) (*OAuthClient, error) {
|
||||
var client OAuthClient
|
||||
err := DB.Where("id = ? AND status = ?", id, common.UserStatusEnabled).First(&client).Error
|
||||
return &client, err
|
||||
}
|
||||
|
||||
// GetAllOAuthClients 获取所有OAuth客户端
|
||||
func GetAllOAuthClients(startIdx int, num int) ([]*OAuthClient, error) {
|
||||
var clients []*OAuthClient
|
||||
err := DB.Order("created_time desc").Limit(num).Offset(startIdx).Find(&clients).Error
|
||||
return clients, err
|
||||
}
|
||||
|
||||
// SearchOAuthClients 搜索OAuth客户端
|
||||
func SearchOAuthClients(keyword string) ([]*OAuthClient, error) {
|
||||
var clients []*OAuthClient
|
||||
err := DB.Where("name LIKE ? OR id LIKE ? OR description LIKE ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%").Find(&clients).Error
|
||||
return clients, err
|
||||
}
|
||||
|
||||
// CreateOAuthClient 创建OAuth客户端
|
||||
func CreateOAuthClient(client *OAuthClient) error {
|
||||
return DB.Create(client).Error
|
||||
}
|
||||
|
||||
// UpdateOAuthClient 更新OAuth客户端
|
||||
func UpdateOAuthClient(client *OAuthClient) error {
|
||||
return DB.Save(client).Error
|
||||
}
|
||||
|
||||
// DeleteOAuthClient 删除OAuth客户端
|
||||
func DeleteOAuthClient(id string) error {
|
||||
return DB.Where("id = ?", id).Delete(&OAuthClient{}).Error
|
||||
}
|
||||
|
||||
// CountOAuthClients 统计OAuth客户端数量
|
||||
func CountOAuthClients() (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&OAuthClient{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
57
model/oauth_revoked_token.go
Normal file
57
model/oauth_revoked_token.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var revokedMem sync.Map // jti -> exp(unix)
|
||||
|
||||
func RevokeToken(jti string, exp int64) error {
|
||||
if jti == "" {
|
||||
return nil
|
||||
}
|
||||
// Prefer Redis, else in-memory
|
||||
if common.RedisEnabled {
|
||||
ttl := time.Duration(0)
|
||||
if exp > 0 {
|
||||
ttl = time.Until(time.Unix(exp, 0))
|
||||
}
|
||||
if ttl <= 0 {
|
||||
ttl = time.Minute
|
||||
}
|
||||
key := fmt.Sprintf("oauth:revoked:%s", jti)
|
||||
return common.RedisSet(key, "1", ttl)
|
||||
}
|
||||
if exp <= 0 {
|
||||
exp = time.Now().Add(time.Minute).Unix()
|
||||
}
|
||||
revokedMem.Store(jti, exp)
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsTokenRevoked(jti string) (bool, error) {
|
||||
if jti == "" {
|
||||
return false, nil
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
key := fmt.Sprintf("oauth:revoked:%s", jti)
|
||||
if _, err := common.RedisGet(key); err == nil {
|
||||
return true, nil
|
||||
} else {
|
||||
// Not found or error; treat as not revoked on error to avoid hard failures
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
// In-memory check
|
||||
if v, ok := revokedMem.Load(jti); ok {
|
||||
exp, _ := v.(int64)
|
||||
if exp == 0 || time.Now().Unix() <= exp {
|
||||
return true, nil
|
||||
}
|
||||
revokedMem.Delete(jti)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -82,7 +82,6 @@ func InitOptionMap() {
|
||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||
common.OptionMap["StripePriceId"] = setting.StripePriceId
|
||||
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["StripePromotionCodesEnabled"] = strconv.FormatBool(setting.StripePromotionCodesEnabled)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
@@ -113,9 +112,6 @@ func InitOptionMap() {
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["ImageRatio"] = ratio_setting.ImageRatio2JSONString()
|
||||
common.OptionMap["AudioRatio"] = ratio_setting.AudioRatio2JSONString()
|
||||
common.OptionMap["AudioCompletionRatio"] = ratio_setting.AudioCompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -331,8 +327,6 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "StripeMinTopUp":
|
||||
setting.StripeMinTopUp, _ = strconv.Atoi(value)
|
||||
case "StripePromotionCodesEnabled":
|
||||
setting.StripePromotionCodesEnabled = value == "true"
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
@@ -403,12 +397,6 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "ImageRatio":
|
||||
err = ratio_setting.UpdateImageRatioByJSONString(value)
|
||||
case "AudioRatio":
|
||||
err = ratio_setting.UpdateAudioRatioByJSONString(value)
|
||||
case "AudioCompletionRatio":
|
||||
err = ratio_setting.UpdateAudioCompletionRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
|
||||
@@ -265,7 +265,6 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.LogError(c, "do request failed: "+err.Error())
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
|
||||
@@ -21,10 +21,6 @@ var awsModelIDMap = map[string]string{
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
"nova-canvas-v1:0": "amazon.nova-canvas-v1:0",
|
||||
"nova-reel-v1:0": "amazon.nova-reel-v1:0",
|
||||
"nova-reel-v1:1": "amazon.nova-reel-v1:1",
|
||||
"nova-sonic-v1:0": "amazon.nova-sonic-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -86,27 +82,10 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
"amazon.nova-canvas-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-reel-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-reel-v1:1": {
|
||||
"us": true,
|
||||
},
|
||||
"amazon.nova-sonic-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
}
|
||||
}}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
|
||||
@@ -3,17 +3,17 @@ package deepseek
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := claude.Adaptor{}
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
@@ -44,19 +44,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
fimBaseUrl := info.ChannelBaseUrl
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
default:
|
||||
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
|
||||
fimBaseUrl += "/beta"
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,17 +87,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if info.IsStream {
|
||||
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
} else {
|
||||
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
|
||||
}
|
||||
default:
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.DoResponse(c, resp, info)
|
||||
if info.IsStream {
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -215,8 +215,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
if strings.Contains(info.RequestURLPath, ":embedContent") ||
|
||||
strings.Contains(info.RequestURLPath, ":batchEmbedContents") {
|
||||
if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
|
||||
strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
|
||||
return NativeGeminiEmbeddingHandler(c, resp, info)
|
||||
}
|
||||
if info.IsStream {
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
|
||||
var geminiSupportedMimeTypes = map[string]bool{
|
||||
"application/pdf": true,
|
||||
"audio/mpeg": true,
|
||||
@@ -31,7 +30,6 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/webp": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
@@ -245,7 +243,6 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
|
||||
googleSearch := false
|
||||
codeExecution := false
|
||||
urlContext := false
|
||||
for _, tool := range textRequest.Tools {
|
||||
if tool.Function.Name == "googleSearch" {
|
||||
googleSearch = true
|
||||
@@ -255,10 +252,6 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
codeExecution = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Name == "urlContext" {
|
||||
urlContext = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
@@ -286,11 +279,6 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if urlContext {
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
URLContext: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiTools = append(geminiTools, dto.GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
|
||||
@@ -25,7 +25,7 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
||||
adaptor := claude.Adaptor{}
|
||||
adaptor := openai.Adaptor{}
|
||||
return adaptor.ConvertClaudeRequest(c, info, req)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -18,7 +17,10 @@ import (
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
openaiAdaptor := openai.Adaptor{}
|
||||
@@ -29,21 +31,32 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
|
||||
IncludeUsage: true,
|
||||
}
|
||||
// map to ollama chat request (Claude -> OpenAI -> Ollama chat)
|
||||
return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") }
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil }
|
||||
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil }
|
||||
return info.ChannelBaseUrl + "/api/chat", nil
|
||||
if info.RelayFormat == types.RelayFormatClaude {
|
||||
return info.ChannelBaseUrl + "/v1/chat/completions", nil
|
||||
}
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.ChannelBaseUrl + "/api/embed", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -53,12 +66,10 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil { return nil, errors.New("request is nil") }
|
||||
// decide generate or chat
|
||||
if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions {
|
||||
return openAIToGenerate(c, request)
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return openAIChatToOllamaChat(c, request)
|
||||
return requestOpenAI2Ollama(c, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -69,7 +80,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return requestOpenAI2Embeddings(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") }
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
@@ -78,13 +92,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return ollamaEmbeddingHandler(c, info, resp)
|
||||
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
if info.IsStream {
|
||||
return ollamaStreamHandler(c, info, resp)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return ollamaChatHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
@@ -2,69 +2,48 @@ package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
type OllamaChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Topp float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaTool struct {
|
||||
Type string `json:"type"`
|
||||
Function OllamaToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
type OllamaToolCall struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type OllamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaChatMessage `json:"messages"`
|
||||
Tools interface{} `json:"tools,omitempty"`
|
||||
Format interface{} `json:"format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
KeepAlive interface{} `json:"keep_alive,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaGenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Format interface{} `json:"format,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
KeepAlive interface{} `json:"keep_alive,omitempty"`
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
type Options struct {
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
NumCtx int `json:"num_ctx,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input interface{} `json:"input"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Input []string `json:"input"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float64 `json:"embeddings"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embedding [][]float64 `json:"embeddings,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -15,176 +14,121 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) {
|
||||
chatReq := &OllamaChatRequest{
|
||||
Model: r.Model,
|
||||
Stream: r.Stream,
|
||||
Options: map[string]any{},
|
||||
Think: r.Think,
|
||||
}
|
||||
if r.ResponseFormat != nil {
|
||||
if r.ResponseFormat.Type == "json" {
|
||||
chatReq.Format = "json"
|
||||
} else if r.ResponseFormat.Type == "json_schema" {
|
||||
if len(r.ResponseFormat.JsonSchema) > 0 {
|
||||
var schema any
|
||||
_ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema)
|
||||
chatReq.Format = schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// options mapping
|
||||
if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP }
|
||||
if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK }
|
||||
if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) }
|
||||
if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) }
|
||||
|
||||
if r.Stop != nil {
|
||||
switch v := r.Stop.(type) {
|
||||
case string:
|
||||
chatReq.Options["stop"] = []string{v}
|
||||
case []string:
|
||||
chatReq.Options["stop"] = v
|
||||
case []any:
|
||||
arr := make([]string,0,len(v))
|
||||
for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } }
|
||||
if len(arr)>0 { chatReq.Options["stop"] = arr }
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.Tools) > 0 {
|
||||
tools := make([]OllamaTool,0,len(r.Tools))
|
||||
for _, t := range r.Tools {
|
||||
tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}})
|
||||
}
|
||||
chatReq.Tools = tools
|
||||
}
|
||||
|
||||
chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages))
|
||||
for _, m := range r.Messages {
|
||||
var textBuilder strings.Builder
|
||||
var images []string
|
||||
if m.IsStringContent() {
|
||||
textBuilder.WriteString(m.StringContent())
|
||||
} else {
|
||||
parts := m.ParseContent()
|
||||
for _, part := range parts {
|
||||
if part.Type == dto.ContentTypeImageURL {
|
||||
img := part.GetImageMedia()
|
||||
if img != nil && img.Url != "" {
|
||||
var base64Data string
|
||||
if strings.HasPrefix(img.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
|
||||
if err != nil { return nil, err }
|
||||
base64Data = fileData.Base64Data
|
||||
} else if strings.HasPrefix(img.Url, "data:") {
|
||||
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] }
|
||||
} else {
|
||||
base64Data = img.Url
|
||||
func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.GetImageMedia()
|
||||
// check if not base64
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if base64Data != "" { images = append(images, base64Data) }
|
||||
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
|
||||
}
|
||||
} else if part.Type == dto.ContentTypeText {
|
||||
textBuilder.WriteString(part.Text)
|
||||
mediaMessage.ImageUrl = imageUrl
|
||||
mediaMessages[j] = mediaMessage
|
||||
}
|
||||
}
|
||||
message.SetMediaContent(mediaMessages)
|
||||
}
|
||||
cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()}
|
||||
if len(images)>0 { cm.Images = images }
|
||||
if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name }
|
||||
if m.ToolCalls != nil && len(m.ToolCalls) > 0 {
|
||||
parsed := m.ParseToolCalls()
|
||||
if len(parsed) > 0 {
|
||||
calls := make([]OllamaToolCall,0,len(parsed))
|
||||
for _, tc := range parsed {
|
||||
var args interface{}
|
||||
if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) }
|
||||
if args==nil { args = map[string]any{} }
|
||||
oc := OllamaToolCall{}
|
||||
oc.Function.Name = tc.Function.Name
|
||||
oc.Function.Arguments = args
|
||||
calls = append(calls, oc)
|
||||
}
|
||||
cm.ToolCalls = calls
|
||||
}
|
||||
}
|
||||
chatReq.Messages = append(chatReq.Messages, cm)
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
})
|
||||
}
|
||||
return chatReq, nil
|
||||
str, ok := request.Stop.(string)
|
||||
var Stop []string
|
||||
if ok {
|
||||
Stop = []string{str}
|
||||
} else {
|
||||
Stop, _ = request.Stop.([]string)
|
||||
}
|
||||
ollamaRequest := &OllamaRequest{
|
||||
Model: request.Model,
|
||||
Messages: messages,
|
||||
Stream: request.Stream,
|
||||
Temperature: request.Temperature,
|
||||
Seed: request.Seed,
|
||||
Topp: request.TopP,
|
||||
TopK: request.TopK,
|
||||
Stop: Stop,
|
||||
Tools: request.Tools,
|
||||
MaxTokens: request.GetMaxTokens(),
|
||||
ResponseFormat: request.ResponseFormat,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
Prompt: request.Prompt,
|
||||
StreamOptions: request.StreamOptions,
|
||||
Suffix: request.Suffix,
|
||||
}
|
||||
ollamaRequest.Think = request.Think
|
||||
return ollamaRequest, nil
|
||||
}
|
||||
|
||||
// openAIToGenerate converts OpenAI completions request to Ollama generate
|
||||
func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) {
|
||||
gen := &OllamaGenerateRequest{
|
||||
Model: r.Model,
|
||||
Stream: r.Stream,
|
||||
Options: map[string]any{},
|
||||
Think: r.Think,
|
||||
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
return &OllamaEmbeddingRequest{
|
||||
Model: request.Model,
|
||||
Input: request.ParseInput(),
|
||||
Options: &Options{
|
||||
Seed: int(request.Seed),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
FrequencyPenalty: request.FrequencyPenalty,
|
||||
PresencePenalty: request.PresencePenalty,
|
||||
},
|
||||
}
|
||||
// Prompt may be in r.Prompt (string or []any)
|
||||
if r.Prompt != nil {
|
||||
switch v := r.Prompt.(type) {
|
||||
case string:
|
||||
gen.Prompt = v
|
||||
case []any:
|
||||
var sb strings.Builder
|
||||
for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } }
|
||||
gen.Prompt = sb.String()
|
||||
default:
|
||||
gen.Prompt = fmt.Sprintf("%v", r.Prompt)
|
||||
}
|
||||
}
|
||||
if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } }
|
||||
if r.ResponseFormat != nil {
|
||||
if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema }
|
||||
}
|
||||
if r.Temperature != nil { gen.Options["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { gen.Options["top_p"] = r.TopP }
|
||||
if r.TopK != 0 { gen.Options["top_k"] = r.TopK }
|
||||
if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) }
|
||||
if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) }
|
||||
if r.Stop != nil {
|
||||
switch v := r.Stop.(type) {
|
||||
case string: gen.Options["stop"] = []string{v}
|
||||
case []string: gen.Options["stop"] = v
|
||||
case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr }
|
||||
}
|
||||
}
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest {
|
||||
opts := map[string]any{}
|
||||
if r.Temperature != nil { opts["temperature"] = r.Temperature }
|
||||
if r.TopP != 0 { opts["top_p"] = r.TopP }
|
||||
if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty }
|
||||
if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty }
|
||||
if r.Seed != 0 { opts["seed"] = int(r.Seed) }
|
||||
if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions }
|
||||
input := r.ParseInput()
|
||||
if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} }
|
||||
return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions}
|
||||
}
|
||||
|
||||
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var oResp OllamaEmbeddingResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings))
|
||||
for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) }
|
||||
usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount}
|
||||
embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage}
|
||||
out, _ := common.Marshal(embResp)
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if ollamaEmbeddingResponse.Error != "" {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||
data = append(data, dto.OpenAIEmbeddingResponseItem{
|
||||
Embedding: flattenedEmbeddings,
|
||||
Object: "embedding",
|
||||
})
|
||||
usage := &dto.Usage{
|
||||
TotalTokens: info.PromptTokens,
|
||||
CompletionTokens: 0,
|
||||
PromptTokens: info.PromptTokens,
|
||||
}
|
||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: data,
|
||||
Model: info.UpstreamModelName,
|
||||
Usage: *usage,
|
||||
}
|
||||
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
||||
flattened := []float64{}
|
||||
for _, row := range embeddings {
|
||||
flattened = append(flattened, row...)
|
||||
}
|
||||
return flattened
|
||||
}
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ollamaChatStreamChunk struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
// chat
|
||||
Message *struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Thinking json.RawMessage `json:"thinking"`
|
||||
ToolCalls []struct {
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
// generate
|
||||
Response string `json:"response"`
|
||||
Done bool `json:"done"`
|
||||
DoneReason string `json:"done_reason"`
|
||||
TotalDuration int64 `json:"total_duration"`
|
||||
LoadDuration int64 `json:"load_duration"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
PromptEvalDuration int64 `json:"prompt_eval_duration"`
|
||||
EvalDuration int64 `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func toUnix(ts string) int64 {
|
||||
if ts == "" { return time.Now().Unix() }
|
||||
// try time.RFC3339 or with nanoseconds
|
||||
t, err := time.Parse(time.RFC3339Nano, ts)
|
||||
if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() }
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) }
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
usage := &dto.Usage{}
|
||||
var model = info.UpstreamModelName
|
||||
var responseId = common.GetUUID()
|
||||
var created = time.Now().Unix()
|
||||
var toolCallIndex int
|
||||
start := helper.GenerateStartEmptyResponse(responseId, created, model, nil)
|
||||
if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" { continue }
|
||||
var chunk ollamaChatStreamChunk
|
||||
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
||||
logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line)
|
||||
return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if chunk.Model != "" { model = chunk.Model }
|
||||
created = toUnix(chunk.CreatedAt)
|
||||
|
||||
if !chunk.Done {
|
||||
// delta content
|
||||
var content string
|
||||
if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response }
|
||||
delta := dto.ChatCompletionsStreamResponse{
|
||||
Id: responseId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: created,
|
||||
Model: model,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{ {
|
||||
Index: 0,
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" },
|
||||
} },
|
||||
}
|
||||
if content != "" { delta.Choices[0].Delta.SetContentString(content) }
|
||||
if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(chunk.Message.Thinking))
|
||||
if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
|
||||
}
|
||||
// tool calls
|
||||
if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
|
||||
delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
|
||||
for _, tc := range chunk.Message.ToolCalls {
|
||||
// arguments -> string
|
||||
argBytes, _ := json.Marshal(tc.Function.Arguments)
|
||||
toolId := fmt.Sprintf("call_%d", toolCallIndex)
|
||||
tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
|
||||
tr.SetIndex(toolCallIndex)
|
||||
toolCallIndex++
|
||||
delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
|
||||
}
|
||||
}
|
||||
if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
continue
|
||||
}
|
||||
// done frame
|
||||
// finalize once and break loop
|
||||
usage.PromptTokens = chunk.PromptEvalCount
|
||||
usage.CompletionTokens = chunk.EvalCount
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
finishReason := chunk.DoneReason
|
||||
if finishReason == "" { finishReason = "stop" }
|
||||
// emit stop delta
|
||||
if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
|
||||
if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
}
|
||||
// emit usage frame
|
||||
if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil {
|
||||
if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) }
|
||||
}
|
||||
// send [DONE]
|
||||
helper.Done(c)
|
||||
break
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) }
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// non-stream handler for chat/generate
|
||||
func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) }
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
raw := string(body)
|
||||
if common.DebugEnabled { println("ollama non-stream raw resp:", raw) }
|
||||
|
||||
lines := strings.Split(raw, "\n")
|
||||
var (
|
||||
aggContent strings.Builder
|
||||
reasoningBuilder strings.Builder
|
||||
lastChunk ollamaChatStreamChunk
|
||||
parsedAny bool
|
||||
)
|
||||
for _, ln := range lines {
|
||||
ln = strings.TrimSpace(ln)
|
||||
if ln == "" { continue }
|
||||
var ck ollamaChatStreamChunk
|
||||
if err := json.Unmarshal([]byte(ln), &ck); err != nil {
|
||||
if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
continue
|
||||
}
|
||||
parsedAny = true
|
||||
lastChunk = ck
|
||||
if ck.Message != nil && len(ck.Message.Thinking) > 0 {
|
||||
raw := strings.TrimSpace(string(ck.Message.Thinking))
|
||||
if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
|
||||
}
|
||||
if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
|
||||
}
|
||||
|
||||
if !parsedAny {
|
||||
var single ollamaChatStreamChunk
|
||||
if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
|
||||
lastChunk = single
|
||||
if single.Message != nil {
|
||||
if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
|
||||
aggContent.WriteString(single.Message.Content)
|
||||
} else { aggContent.WriteString(single.Response) }
|
||||
}
|
||||
|
||||
model := lastChunk.Model
|
||||
if model == "" { model = info.UpstreamModelName }
|
||||
created := toUnix(lastChunk.CreatedAt)
|
||||
usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount}
|
||||
content := aggContent.String()
|
||||
finishReason := lastChunk.DoneReason
|
||||
if finishReason == "" { finishReason = "stop" }
|
||||
|
||||
msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
|
||||
if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc }
|
||||
full := dto.OpenAITextResponse{
|
||||
Id: common.GetUUID(),
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Created: created,
|
||||
Choices: []dto.OpenAITextResponseChoice{ {
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
} },
|
||||
Usage: *usage,
|
||||
}
|
||||
out, _ := common.Marshal(full)
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func contentPtr(s string) *string { if s=="" { return nil }; return &s }
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -186,27 +185,10 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
if common.DebugEnabled {
|
||||
println("upstream response body:", string(responseBody))
|
||||
}
|
||||
// Unmarshal to simpleResponse
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
|
||||
// 尝试解析为 openrouter enterprise
|
||||
var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
|
||||
err = common.Unmarshal(responseBody, &enterpriseResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
if enterpriseResponse.Success {
|
||||
responseBody = enterpriseResponse.Data
|
||||
} else {
|
||||
logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -33,12 +33,6 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
if responsesResponse.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", responsesResponse.GetQuality())
|
||||
c.Set("image_generation_call_size", responsesResponse.GetSize())
|
||||
}
|
||||
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
@@ -86,25 +80,18 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response != nil {
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResponse.Response.HasImageGenerationCall() {
|
||||
c.Set("image_generation_call", true)
|
||||
c.Set("image_generation_call_quality", streamResponse.Response.GetQuality())
|
||||
c.Set("image_generation_call_size", streamResponse.Response.GetSize())
|
||||
if streamResponse.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
|
||||
}
|
||||
if streamResponse.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
case "response.output_text.delta":
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package openrouter
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type RequestReasoning struct {
|
||||
// One of the following (not both):
|
||||
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
|
||||
@@ -9,8 +7,3 @@ type RequestReasoning struct {
|
||||
// Optional: Default is false. All models support this.
|
||||
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
|
||||
}
|
||||
|
||||
type OpenRouterEnterpriseResponse struct {
|
||||
Data json.RawMessage `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
@@ -94,9 +94,6 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
return fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
@@ -104,12 +101,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
} else {
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
return nil
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
@@ -169,9 +161,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
||||
if isNewAPIRelay(key) {
|
||||
uri = fmt.Sprintf("%s/jimeng/?Action=CVSync2AsyncGetResult&Version=2022-08-31", a.baseURL)
|
||||
}
|
||||
payload := map[string]string{
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
@@ -189,20 +178,17 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if isNewAPIRelay(key) {
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
} else {
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
@@ -398,7 +384,3 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
taskResult.Url = resTask.Data.VideoUrl
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
@@ -117,11 +117,6 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
|
||||
if isNewAPIRelay(info.ApiKey) {
|
||||
return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
@@ -204,9 +199,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
}
|
||||
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||
if isNewAPIRelay(key) {
|
||||
url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
@@ -312,13 +304,8 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
//}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
if isNewAPIRelay(apiKey) {
|
||||
return apiKey, nil // new api relay
|
||||
}
|
||||
|
||||
keyParts := strings.Split(apiKey, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return "", errors.New("invalid api_key, required format is accessKey|secretKey")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
if len(keyParts) == 1 {
|
||||
return accessKey, nil
|
||||
@@ -365,7 +352,3 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
func isNewAPIRelay(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, "sk-")
|
||||
}
|
||||
|
||||
@@ -80,7 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
@@ -111,10 +112,6 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
switch info.Action {
|
||||
case constant.TaskActionGenerate:
|
||||
path = "/img2video"
|
||||
case constant.TaskActionFirstTailGenerate:
|
||||
path = "/start-end2video"
|
||||
case constant.TaskActionReferenceGenerate:
|
||||
path = "/reference2video"
|
||||
default:
|
||||
path = "/text2video"
|
||||
}
|
||||
@@ -190,9 +187,14 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
}
|
||||
|
||||
r := requestPayload{
|
||||
Model: defaultString(req.Model, "viduq1"),
|
||||
Images: req.Images,
|
||||
Images: images,
|
||||
Prompt: req.Prompt,
|
||||
Duration: defaultInt(req.Duration, 5),
|
||||
Resolution: defaultString(req.Size, "1080p"),
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
channelconstant "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
@@ -42,8 +41,6 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return request, nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
@@ -189,26 +186,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// 支持自定义域名,如果未设置则使用默认域名
|
||||
baseUrl := info.ChannelBaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine]
|
||||
}
|
||||
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
|
||||
}
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeImagesEdits:
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
|
||||
@@ -8,12 +8,6 @@ var ModelList = []string{
|
||||
"Doubao-lite-32k",
|
||||
"Doubao-lite-4k",
|
||||
"Doubao-embedding",
|
||||
"doubao-seedream-4-0-250828",
|
||||
"seedream-4-0-250828",
|
||||
"doubao-seedance-1-0-pro-250528",
|
||||
"seedance-1-0-pro-250528",
|
||||
"doubao-seed-1-6-thinking-250715",
|
||||
"seed-1-6-thinking-250715",
|
||||
}
|
||||
|
||||
var ChannelName = "volcengine"
|
||||
|
||||
@@ -207,6 +207,10 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||
err = conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
@@ -216,9 +220,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
defer func() {
|
||||
conn.Close()
|
||||
}()
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
@@ -70,31 +69,6 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
if request.System == nil {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
if request.IsStringSystem() {
|
||||
existing := strings.TrimSpace(request.GetStringSystem())
|
||||
if existing == "" {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt)
|
||||
} else {
|
||||
request.SetStringSystem(info.ChannelSetting.SystemPrompt + "\n" + existing)
|
||||
}
|
||||
} else {
|
||||
systemContents := request.ParseSystem()
|
||||
newSystem := dto.ClaudeMediaMessage{Type: dto.ContentTypeText}
|
||||
newSystem.SetText(info.ChannelSetting.SystemPrompt)
|
||||
if len(systemContents) == 0 {
|
||||
request.System = []dto.ClaudeMediaMessage{newSystem}
|
||||
} else {
|
||||
request.System = append([]dto.ClaudeMediaMessage{newSystem}, systemContents...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
|
||||
@@ -79,18 +79,34 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
|
||||
req.Images = []string{req.Image}
|
||||
}
|
||||
|
||||
if req.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
if info.ChannelType == constant.ChannelTypeVidu {
|
||||
// vidu 增加 首尾帧生视频和参考图生视频
|
||||
if len(req.Images) == 2 {
|
||||
action = constant.TaskActionFirstTailGenerate
|
||||
} else if len(req.Images) > 2 {
|
||||
action = constant.TaskActionReferenceGenerate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
|
||||
hasPrompt, ok := requestObj.(HasPrompt)
|
||||
if !ok {
|
||||
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
action := constant.TaskActionTextGenerate
|
||||
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, requestObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
|
||||
}
|
||||
|
||||
return ValidateTaskRequestWithImage(c, info, req)
|
||||
}
|
||||
|
||||
@@ -90,41 +90,39 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
// 如果有系统提示,则将其添加到请求中
|
||||
request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
if ok {
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
request := convertedRequest.(*dto.GeneralOpenAIRequest)
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
break
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
// 如果没有系统提示,则添加系统提示
|
||||
systemMessage := dto.Message{
|
||||
Role: request.GetSystemRoleName(),
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
// 如果有系统提示,且允许覆盖,则拼接到前面
|
||||
for i, message := range request.Messages {
|
||||
if message.Role == request.GetSystemRoleName() {
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -278,13 +276,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
var imageGenerationCallPrice float64
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
|
||||
@@ -340,8 +331,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
|
||||
// 添加 audio input 独立计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
@@ -440,10 +429,6 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["audio_input_token_count"] = audioTokens
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
if !dImageGenerationCallQuota.IsZero() {
|
||||
other["image_generation_call"] = true
|
||||
other["image_generation_call_price"] = imageGenerationCallPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/relay/channel/gemini"
|
||||
@@ -95,32 +94,6 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
if info.ChannelSetting.SystemPrompt != "" {
|
||||
if request.SystemInstructions == nil {
|
||||
request.SystemInstructions = &dto.GeminiChatContent{
|
||||
Parts: []dto.GeminiPart{
|
||||
{Text: info.ChannelSetting.SystemPrompt},
|
||||
},
|
||||
}
|
||||
} else if len(request.SystemInstructions.Parts) == 0 {
|
||||
request.SystemInstructions.Parts = []dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}
|
||||
} else if info.ChannelSetting.SystemPromptOverride {
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
merged := false
|
||||
for i := range request.SystemInstructions.Parts {
|
||||
if request.SystemInstructions.Parts[i].Text == "" {
|
||||
continue
|
||||
}
|
||||
request.SystemInstructions.Parts[i].Text = info.ChannelSetting.SystemPrompt + "\n" + request.SystemInstructions.Parts[i].Text
|
||||
merged = true
|
||||
break
|
||||
}
|
||||
if !merged {
|
||||
request.SystemInstructions.Parts = append([]dto.GeminiPart{{Text: info.ChannelSetting.SystemPrompt}}, request.SystemInstructions.Parts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up empty system instruction
|
||||
if request.SystemInstructions != nil {
|
||||
hasContent := false
|
||||
|
||||
@@ -52,8 +52,6 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
var cacheRatio float64
|
||||
var imageRatio float64
|
||||
var cacheCreationRatio float64
|
||||
var audioRatio float64
|
||||
var audioCompletionRatio float64
|
||||
if !usePrice {
|
||||
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
|
||||
if meta.MaxTokens != 0 {
|
||||
@@ -75,8 +73,6 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
|
||||
audioRatio = ratio_setting.GetAudioRatio(info.OriginModelName)
|
||||
audioCompletionRatio = ratio_setting.GetAudioCompletionRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
@@ -94,8 +90,6 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
AudioRatio: audioRatio,
|
||||
AudioCompletionRatio: audioCompletionRatio,
|
||||
CacheCreationRatio: cacheCreationRatio,
|
||||
ShouldPreConsumedQuota: preConsumedQuota,
|
||||
}
|
||||
|
||||
@@ -21,11 +21,7 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
request, err = GetAndValidateGeminiEmbeddingRequest(c)
|
||||
} else {
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
}
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
case types.RelayFormatClaude:
|
||||
request, err = GetAndValidateClaudeRequest(c)
|
||||
case types.RelayFormatOpenAIResponses:
|
||||
@@ -292,6 +288,7 @@ func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenA
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
request := &dto.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
@@ -307,12 +304,3 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingRequest, error) {
|
||||
request := &dto.GeminiEmbeddingRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -31,6 +31,21 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
|
||||
// OAuth2 Server endpoints
|
||||
apiRouter.GET("/.well-known/jwks.json", controller.GetJWKS)
|
||||
apiRouter.GET("/.well-known/openid-configuration", controller.OAuthOIDCConfiguration)
|
||||
apiRouter.GET("/.well-known/oauth-authorization-server", controller.OAuthServerInfo)
|
||||
apiRouter.POST("/oauth/token", middleware.CriticalRateLimit(), controller.OAuthTokenEndpoint)
|
||||
apiRouter.GET("/oauth/authorize", controller.OAuthAuthorizeEndpoint)
|
||||
apiRouter.POST("/oauth/introspect", middleware.AdminAuth(), controller.OAuthIntrospect)
|
||||
apiRouter.POST("/oauth/revoke", middleware.CriticalRateLimit(), controller.OAuthRevoke)
|
||||
apiRouter.GET("/oauth/userinfo", middleware.OAuthJWTAuth(), controller.OAuthUserInfo)
|
||||
|
||||
// OAuth2 管理API (前端使用)
|
||||
apiRouter.GET("/oauth/jwks", controller.GetJWKS)
|
||||
apiRouter.GET("/oauth/server-info", controller.OAuthServerInfo)
|
||||
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
@@ -40,6 +55,17 @@ func SetApiRouter(router *gin.Engine) {
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
|
||||
// OAuth2 admin operations
|
||||
oauthAdmin := apiRouter.Group("/oauth")
|
||||
oauthAdmin.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
|
||||
{
|
||||
oauthAdmin.POST("/keys/rotate", controller.RotateOAuthSigningKey)
|
||||
oauthAdmin.GET("/keys", controller.ListOAuthSigningKeys)
|
||||
oauthAdmin.DELETE("/keys/:kid", controller.DeleteOAuthSigningKey)
|
||||
oauthAdmin.POST("/keys/generate_file", controller.GenerateOAuthSigningKeyFile)
|
||||
oauthAdmin.POST("/keys/import_pem", controller.ImportOAuthSigningKey)
|
||||
}
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
@@ -78,7 +104,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
adminRoute.Use(middleware.AdminAuth())
|
||||
adminRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
adminRoute.GET("/", controller.GetAllUsers)
|
||||
adminRoute.GET("/search", controller.SearchUsers)
|
||||
@@ -94,7 +120,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
optionRoute.Use(middleware.RootAuth())
|
||||
optionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.RootAuth())
|
||||
{
|
||||
optionRoute.GET("/", controller.GetOptions)
|
||||
optionRoute.PUT("/", controller.UpdateOption)
|
||||
@@ -108,7 +134,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
|
||||
}
|
||||
channelRoute := apiRouter.Group("/channel")
|
||||
channelRoute.Use(middleware.AdminAuth())
|
||||
channelRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
channelRoute.GET("/", controller.GetAllChannels)
|
||||
channelRoute.GET("/search", controller.SearchChannels)
|
||||
@@ -159,7 +185,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
redemptionRoute := apiRouter.Group("/redemption")
|
||||
redemptionRoute.Use(middleware.AdminAuth())
|
||||
redemptionRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
redemptionRoute.GET("/", controller.GetAllRedemptions)
|
||||
redemptionRoute.GET("/search", controller.SearchRedemptions)
|
||||
@@ -187,13 +213,13 @@ func SetApiRouter(router *gin.Engine) {
|
||||
logRoute.GET("/token", controller.GetLogByKey)
|
||||
}
|
||||
groupRoute := apiRouter.Group("/group")
|
||||
groupRoute.Use(middleware.AdminAuth())
|
||||
groupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
groupRoute.GET("/", controller.GetGroups)
|
||||
}
|
||||
|
||||
prefillGroupRoute := apiRouter.Group("/prefill_group")
|
||||
prefillGroupRoute.Use(middleware.AdminAuth())
|
||||
prefillGroupRoute.Use(middleware.OptionalOAuthAuth(), middleware.RequireOAuthScopeIfPresent("admin"), middleware.AdminAuth())
|
||||
{
|
||||
prefillGroupRoute.GET("/", controller.GetPrefillGroups)
|
||||
prefillGroupRoute.POST("/", controller.CreatePrefillGroup)
|
||||
@@ -235,5 +261,17 @@ func SetApiRouter(router *gin.Engine) {
|
||||
modelsRoute.PUT("/", controller.UpdateModelMeta)
|
||||
modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
|
||||
}
|
||||
|
||||
oauthClientsRoute := apiRouter.Group("/oauth_clients")
|
||||
oauthClientsRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
oauthClientsRoute.GET("/", controller.GetAllOAuthClients)
|
||||
oauthClientsRoute.GET("/search", controller.SearchOAuthClients)
|
||||
oauthClientsRoute.GET("/:id", controller.GetOAuthClient)
|
||||
oauthClientsRoute.POST("/", controller.CreateOAuthClient)
|
||||
oauthClientsRoute.PUT("/", controller.UpdateOAuthClient)
|
||||
oauthClientsRoute.DELETE("/:id", controller.DeleteOAuthClient)
|
||||
oauthClientsRoute.POST("/:id/regenerate_secret", controller.RegenerateOAuthClientSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,12 +28,6 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
|
||||
// SSRF防护:验证请求URL
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return nil, fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
workerUrl := system_setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
@@ -57,13 +51,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
|
||||
}
|
||||
return DoWorkerRequest(req)
|
||||
} else {
|
||||
// SSRF防护:验证请求URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return nil, fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", ")))
|
||||
common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
@@ -7,17 +7,12 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
httpClient *http.Client
|
||||
proxyClientLock sync.Mutex
|
||||
proxyClients = make(map[string]*http.Client)
|
||||
)
|
||||
var httpClient *http.Client
|
||||
|
||||
func InitHttpClient() {
|
||||
if common.RelayTimeout == 0 {
|
||||
@@ -33,31 +28,12 @@ func GetHttpClient() *http.Client {
|
||||
return httpClient
|
||||
}
|
||||
|
||||
// ResetProxyClientCache 清空代理客户端缓存,确保下次使用时重新初始化
|
||||
func ResetProxyClientCache() {
|
||||
proxyClientLock.Lock()
|
||||
defer proxyClientLock.Unlock()
|
||||
for _, client := range proxyClients {
|
||||
if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
proxyClients = make(map[string]*http.Client)
|
||||
}
|
||||
|
||||
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
|
||||
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
if proxyURL == "" {
|
||||
return http.DefaultClient, nil
|
||||
}
|
||||
|
||||
proxyClientLock.Lock()
|
||||
if client, ok := proxyClients[proxyURL]; ok {
|
||||
proxyClientLock.Unlock()
|
||||
return client, nil
|
||||
}
|
||||
proxyClientLock.Unlock()
|
||||
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -65,16 +41,11 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
client := &http.Client{
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
},
|
||||
}
|
||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||
proxyClientLock.Lock()
|
||||
proxyClients[proxyURL] = client
|
||||
proxyClientLock.Unlock()
|
||||
return client, nil
|
||||
}, nil
|
||||
|
||||
case "socks5", "socks5h":
|
||||
// 获取认证信息
|
||||
@@ -96,18 +67,13 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
},
|
||||
}
|
||||
client.Timeout = time.Duration(common.RelayTimeout) * time.Second
|
||||
proxyClientLock.Lock()
|
||||
proxyClients[proxyURL] = client
|
||||
proxyClientLock.Unlock()
|
||||
return client, nil
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||
|
||||
@@ -19,7 +19,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
|
||||
err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
|
||||
@@ -113,12 +113,6 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
|
||||
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
// SSRF防护:验证Bark URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
// 直接发送请求
|
||||
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/setting/system_setting"
|
||||
"time"
|
||||
@@ -87,12 +86,6 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
||||
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
// SSRF防护:验证Webhook URL(非Worker模式)
|
||||
fetchSetting := system_setting.GetFetchSetting()
|
||||
if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil {
|
||||
return fmt.Errorf("request reject: %v", err)
|
||||
}
|
||||
|
||||
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create webhook request: %v", err)
|
||||
|
||||
@@ -10,18 +10,6 @@ const (
|
||||
FileSearchPrice = 2.5
|
||||
)
|
||||
|
||||
const (
|
||||
GPTImage1Low1024x1024 = 0.011
|
||||
GPTImage1Low1024x1536 = 0.016
|
||||
GPTImage1Low1536x1024 = 0.016
|
||||
GPTImage1Medium1024x1024 = 0.042
|
||||
GPTImage1Medium1024x1536 = 0.063
|
||||
GPTImage1Medium1536x1024 = 0.063
|
||||
GPTImage1High1024x1024 = 0.167
|
||||
GPTImage1High1024x1536 = 0.25
|
||||
GPTImage1High1536x1024 = 0.25
|
||||
)
|
||||
|
||||
const (
|
||||
// Gemini Audio Input Price
|
||||
Gemini25FlashPreviewInputAudioPrice = 1.00
|
||||
@@ -77,31 +65,3 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func GetGPTImage1PriceOnceCall(quality string, size string) float64 {
|
||||
prices := map[string]map[string]float64{
|
||||
"low": {
|
||||
"1024x1024": GPTImage1Low1024x1024,
|
||||
"1024x1536": GPTImage1Low1024x1536,
|
||||
"1536x1024": GPTImage1Low1536x1024,
|
||||
},
|
||||
"medium": {
|
||||
"1024x1024": GPTImage1Medium1024x1024,
|
||||
"1024x1536": GPTImage1Medium1024x1536,
|
||||
"1536x1024": GPTImage1Medium1536x1024,
|
||||
},
|
||||
"high": {
|
||||
"1024x1024": GPTImage1High1024x1024,
|
||||
"1024x1536": GPTImage1High1024x1536,
|
||||
"1536x1024": GPTImage1High1536x1024,
|
||||
},
|
||||
}
|
||||
|
||||
if qualityMap, exists := prices[quality]; exists {
|
||||
if price, exists := qualityMap[size]; exists {
|
||||
return price
|
||||
}
|
||||
}
|
||||
|
||||
return GPTImage1High1024x1024
|
||||
}
|
||||
|
||||
@@ -5,4 +5,3 @@ var StripeWebhookSecret = ""
|
||||
var StripePriceId = ""
|
||||
var StripeUnitPrice = 8.0
|
||||
var StripeMinTopUp = 1
|
||||
var StripePromotionCodesEnabled = false
|
||||
|
||||
@@ -178,7 +178,6 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-2.5-flash-lite-preview-thinking-*": 0.05,
|
||||
"gemini-2.5-flash-lite-preview-06-17": 0.05,
|
||||
"gemini-2.5-flash": 0.15,
|
||||
"gemini-embedding-001": 0.075,
|
||||
"text-embedding-004": 0.001,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
@@ -279,18 +278,6 @@ var defaultModelPrice = map[string]float64{
|
||||
"mj_upload": 0.05,
|
||||
}
|
||||
|
||||
var defaultAudioRatio = map[string]float64{
|
||||
"gpt-4o-audio-preview": 16,
|
||||
"gpt-4o-mini-audio-preview": 66.67,
|
||||
"gpt-4o-realtime-preview": 8,
|
||||
"gpt-4o-mini-realtime-preview": 16.67,
|
||||
}
|
||||
|
||||
var defaultAudioCompletionRatio = map[string]float64{
|
||||
"gpt-4o-realtime": 2,
|
||||
"gpt-4o-mini-realtime": 2,
|
||||
}
|
||||
|
||||
var (
|
||||
modelPriceMap map[string]float64 = nil
|
||||
modelPriceMapMutex = sync.RWMutex{}
|
||||
@@ -339,15 +326,6 @@ func InitRatioSettings() {
|
||||
imageRatioMap = defaultImageRatio
|
||||
imageRatioMapMutex.Unlock()
|
||||
|
||||
// initialize audioRatioMap
|
||||
audioRatioMapMutex.Lock()
|
||||
audioRatioMap = defaultAudioRatio
|
||||
audioRatioMapMutex.Unlock()
|
||||
|
||||
// initialize audioCompletionRatioMap
|
||||
audioCompletionRatioMapMutex.Lock()
|
||||
audioCompletionRatioMap = defaultAudioCompletionRatio
|
||||
audioCompletionRatioMapMutex.Unlock()
|
||||
}
|
||||
|
||||
func GetModelPriceMap() map[string]float64 {
|
||||
@@ -439,18 +417,6 @@ func GetDefaultModelRatioMap() map[string]float64 {
|
||||
return defaultModelRatio
|
||||
}
|
||||
|
||||
func GetDefaultImageRatioMap() map[string]float64 {
|
||||
return defaultImageRatio
|
||||
}
|
||||
|
||||
func GetDefaultAudioRatioMap() map[string]float64 {
|
||||
return defaultAudioRatio
|
||||
}
|
||||
|
||||
func GetDefaultAudioCompletionRatioMap() map[string]float64 {
|
||||
return defaultAudioCompletionRatio
|
||||
}
|
||||
|
||||
func GetCompletionRatioMap() map[string]float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
@@ -618,22 +584,32 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
}
|
||||
|
||||
func GetAudioRatio(name string) float64 {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
if ratio, ok := audioRatioMap[name]; ok {
|
||||
return ratio
|
||||
if strings.Contains(name, "-realtime") {
|
||||
if strings.HasSuffix(name, "gpt-4o-realtime-preview") {
|
||||
return 8
|
||||
} else if strings.Contains(name, "gpt-4o-mini-realtime-preview") {
|
||||
return 10 / 0.6
|
||||
} else {
|
||||
return 20
|
||||
}
|
||||
}
|
||||
if strings.Contains(name, "-audio") {
|
||||
if strings.HasPrefix(name, "gpt-4o-audio-preview") {
|
||||
return 40 / 2.5
|
||||
} else if strings.HasPrefix(name, "gpt-4o-mini-audio-preview") {
|
||||
return 10 / 0.15
|
||||
} else {
|
||||
return 40
|
||||
}
|
||||
}
|
||||
return 20
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatio(name string) float64 {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
if ratio, ok := audioCompletionRatioMap[name]; ok {
|
||||
|
||||
return ratio
|
||||
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
return 2
|
||||
} else if strings.HasPrefix(name, "gpt-4o-mini-realtime") {
|
||||
return 2
|
||||
}
|
||||
return 2
|
||||
}
|
||||
@@ -654,14 +630,6 @@ var defaultImageRatio = map[string]float64{
|
||||
}
|
||||
var imageRatioMap map[string]float64
|
||||
var imageRatioMapMutex sync.RWMutex
|
||||
var (
|
||||
audioRatioMap map[string]float64 = nil
|
||||
audioRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
var (
|
||||
audioCompletionRatioMap map[string]float64 = nil
|
||||
audioCompletionRatioMapMutex = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func ImageRatio2JSONString() string {
|
||||
imageRatioMapMutex.RLock()
|
||||
@@ -690,71 +658,6 @@ func GetImageRatio(name string) (float64, bool) {
|
||||
return ratio, true
|
||||
}
|
||||
|
||||
func AudioRatio2JSONString() string {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(audioRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling audio ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateAudioRatioByJSONString(jsonStr string) error {
|
||||
|
||||
tmp := make(map[string]float64)
|
||||
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
audioRatioMapMutex.Lock()
|
||||
audioRatioMap = tmp
|
||||
audioRatioMapMutex.Unlock()
|
||||
InvalidateExposedDataCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAudioRatioCopy() map[string]float64 {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(audioRatioMap))
|
||||
for k, v := range audioRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func AudioCompletionRatio2JSONString() string {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
jsonBytes, err := common.Marshal(audioCompletionRatioMap)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling audio completion ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
|
||||
tmp := make(map[string]float64)
|
||||
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
audioCompletionRatioMapMutex.Lock()
|
||||
audioCompletionRatioMap = tmp
|
||||
audioCompletionRatioMapMutex.Unlock()
|
||||
InvalidateExposedDataCache()
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatioCopy() map[string]float64 {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(audioCompletionRatioMap))
|
||||
for k, v := range audioCompletionRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func GetModelRatioCopy() map[string]float64 {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package system_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type FetchSetting struct {
|
||||
EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护
|
||||
AllowPrivateIp bool `json:"allow_private_ip"`
|
||||
DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式
|
||||
IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式
|
||||
DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com
|
||||
IpList []string `json:"ip_list"` // CIDR format
|
||||
AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000
|
||||
ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性)
|
||||
}
|
||||
|
||||
var defaultFetchSetting = FetchSetting{
|
||||
EnableSSRFProtection: true, // 默认开启SSRF防护
|
||||
AllowPrivateIp: false,
|
||||
DomainFilterMode: false,
|
||||
IpFilterMode: false,
|
||||
DomainList: []string{},
|
||||
IpList: []string{},
|
||||
AllowedPorts: []string{"80", "443", "8080", "8443"},
|
||||
ApplyIPFilterForDomain: false,
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("fetch_setting", &defaultFetchSetting)
|
||||
}
|
||||
|
||||
func GetFetchSetting() *FetchSetting {
|
||||
return &defaultFetchSetting
|
||||
}
|
||||
74
setting/system_setting/oauth2.go
Normal file
74
setting/system_setting/oauth2.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package system_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type OAuth2Settings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Issuer string `json:"issuer"`
|
||||
AccessTokenTTL int `json:"access_token_ttl"` // in minutes
|
||||
RefreshTokenTTL int `json:"refresh_token_ttl"` // in minutes
|
||||
AllowedGrantTypes []string `json:"allowed_grant_types"` // client_credentials, authorization_code, refresh_token
|
||||
RequirePKCE bool `json:"require_pkce"` // force PKCE for authorization code flow
|
||||
JWTSigningAlgorithm string `json:"jwt_signing_algorithm"`
|
||||
JWTKeyID string `json:"jwt_key_id"`
|
||||
JWTPrivateKeyFile string `json:"jwt_private_key_file"`
|
||||
AutoCreateUser bool `json:"auto_create_user"` // auto create user on first OAuth2 login
|
||||
DefaultUserRole int `json:"default_user_role"` // default role for auto-created users
|
||||
DefaultUserGroup string `json:"default_user_group"` // default group for auto-created users
|
||||
ScopeMappings map[string][]string `json:"scope_mappings"` // scope to permissions mapping
|
||||
MaxJWKSKeys int `json:"max_jwks_keys"` // maximum number of JWKS signing keys to retain
|
||||
DefaultPrivateKeyPath string `json:"default_private_key_path"` // suggested private key file path
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultOAuth2Settings = OAuth2Settings{
|
||||
Enabled: false,
|
||||
AccessTokenTTL: 10, // 10 minutes
|
||||
RefreshTokenTTL: 720, // 12 hours
|
||||
AllowedGrantTypes: []string{"client_credentials", "authorization_code", "refresh_token"},
|
||||
RequirePKCE: true,
|
||||
JWTSigningAlgorithm: "RS256",
|
||||
JWTKeyID: "oauth2-key-1",
|
||||
AutoCreateUser: false,
|
||||
DefaultUserRole: 1, // common user
|
||||
DefaultUserGroup: "default",
|
||||
ScopeMappings: map[string][]string{
|
||||
"api:read": {"read"},
|
||||
"api:write": {"write"},
|
||||
"admin": {"admin"},
|
||||
},
|
||||
MaxJWKSKeys: 3,
|
||||
DefaultPrivateKeyPath: "/etc/new-api/oauth2-private.pem",
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("oauth2", &defaultOAuth2Settings)
|
||||
}
|
||||
|
||||
func GetOAuth2Settings() *OAuth2Settings {
|
||||
return &defaultOAuth2Settings
|
||||
}
|
||||
|
||||
// UpdateOAuth2Settings 更新OAuth2配置
|
||||
func UpdateOAuth2Settings(settings OAuth2Settings) {
|
||||
defaultOAuth2Settings = settings
|
||||
}
|
||||
|
||||
// ValidateGrantType 验证授权类型是否被允许
|
||||
func (s *OAuth2Settings) ValidateGrantType(grantType string) bool {
|
||||
for _, allowedType := range s.AllowedGrantTypes {
|
||||
if allowedType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetScopePermissions 获取scope对应的权限
|
||||
func (s *OAuth2Settings) GetScopePermissions(scope string) []string {
|
||||
if perms, exists := s.ScopeMappings[scope]; exists {
|
||||
return perms
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
1069
src/oauth/server.go
Normal file
1069
src/oauth/server.go
Normal file
File diff suppressed because it is too large
Load Diff
82
src/oauth/store.go
Normal file
82
src/oauth/store.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KVStore is a minimal TTL key-value abstraction used by OAuth flows.
|
||||
type KVStore interface {
|
||||
Set(key, value string, ttl time.Duration) error
|
||||
Get(key string) (string, bool)
|
||||
Del(key string) error
|
||||
}
|
||||
|
||||
type redisStore struct{}
|
||||
|
||||
func (r *redisStore) Set(key, value string, ttl time.Duration) error {
|
||||
return common.RedisSet(key, value, ttl)
|
||||
}
|
||||
func (r *redisStore) Get(key string) (string, bool) {
|
||||
v, err := common.RedisGet(key)
|
||||
if err != nil || v == "" {
|
||||
return "", false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
func (r *redisStore) Del(key string) error {
|
||||
return common.RedisDel(key)
|
||||
}
|
||||
|
||||
type memEntry struct {
|
||||
val string
|
||||
exp int64 // unix seconds, 0 means no expiry
|
||||
}
|
||||
|
||||
type memoryStore struct {
|
||||
m sync.Map // key -> memEntry
|
||||
}
|
||||
|
||||
func (m *memoryStore) Set(key, value string, ttl time.Duration) error {
|
||||
var exp int64
|
||||
if ttl > 0 {
|
||||
exp = time.Now().Add(ttl).Unix()
|
||||
}
|
||||
m.m.Store(key, memEntry{val: value, exp: exp})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memoryStore) Get(key string) (string, bool) {
|
||||
v, ok := m.m.Load(key)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
e := v.(memEntry)
|
||||
if e.exp > 0 && time.Now().Unix() > e.exp {
|
||||
m.m.Delete(key)
|
||||
return "", false
|
||||
}
|
||||
return e.val, true
|
||||
}
|
||||
|
||||
func (m *memoryStore) Del(key string) error {
|
||||
m.m.Delete(key)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
memStore = &memoryStore{}
|
||||
rdsStore = &redisStore{}
|
||||
)
|
||||
|
||||
func getStore() KVStore {
|
||||
if common.RedisEnabled {
|
||||
return rdsStore
|
||||
}
|
||||
return memStore
|
||||
}
|
||||
|
||||
func storeSet(key, val string, ttl time.Duration) error { return getStore().Set(key, val, ttl) }
|
||||
func storeGet(key string) (string, bool) { return getStore().Get(key) }
|
||||
func storeDel(key string) error { return getStore().Del(key) }
|
||||
59
src/oauth/util.go
Normal file
59
src/oauth/util.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// getFormOrBasicAuth extracts client_id/client_secret from Basic Auth first, then form
|
||||
func getFormOrBasicAuth(c *gin.Context) (clientID, clientSecret string) {
|
||||
id, secret, ok := c.Request.BasicAuth()
|
||||
if ok {
|
||||
return strings.TrimSpace(id), strings.TrimSpace(secret)
|
||||
}
|
||||
return strings.TrimSpace(c.PostForm("client_id")), strings.TrimSpace(c.PostForm("client_secret"))
|
||||
}
|
||||
|
||||
// genCode generates URL-safe random string based on nBytes of entropy
|
||||
func genCode(nBytes int) (string, error) {
|
||||
b := make([]byte, nBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// s256Base64URL computes base64url-encoded SHA256 digest
|
||||
func s256Base64URL(verifier string) string {
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// writeNoStore sets no-store cache headers for OAuth responses
|
||||
func writeNoStore(c *gin.Context) {
|
||||
c.Header("Cache-Control", "no-store")
|
||||
c.Header("Pragma", "no-cache")
|
||||
}
|
||||
|
||||
// writeOAuthRedirectError builds an error redirect to redirect_uri as RFC6749
|
||||
func writeOAuthRedirectError(c *gin.Context, redirectURI, errCode, description, state string) {
|
||||
writeNoStore(c)
|
||||
q := "error=" + url.QueryEscape(errCode)
|
||||
if description != "" {
|
||||
q += "&error_description=" + url.QueryEscape(description)
|
||||
}
|
||||
if state != "" {
|
||||
q += "&state=" + url.QueryEscape(state)
|
||||
}
|
||||
sep := "?"
|
||||
if strings.Contains(redirectURI, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
c.Redirect(http.StatusFound, redirectURI+sep+q)
|
||||
}
|
||||
@@ -122,9 +122,6 @@ func (e *NewAPIError) MaskSensitiveError() string {
|
||||
return string(e.errorCode)
|
||||
}
|
||||
errStr := e.Err.Error()
|
||||
if e.errorCode == ErrorCodeCountTokenFailed {
|
||||
return errStr
|
||||
}
|
||||
return common.MaskSensitiveInfo(errStr)
|
||||
}
|
||||
|
||||
@@ -156,9 +153,8 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
||||
Code: e.errorCode,
|
||||
}
|
||||
}
|
||||
if e.errorCode != ErrorCodeCountTokenFailed {
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
}
|
||||
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -182,9 +178,7 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
|
||||
Type: string(e.errorType),
|
||||
}
|
||||
}
|
||||
if e.errorCode != ErrorCodeCountTokenFailed {
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
}
|
||||
result.Message = common.MaskSensitiveInfo(result.Message)
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
AudioRatio float64
|
||||
AudioCompletionRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
@@ -29,5 +27,5 @@ type PerCallPriceData struct {
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
|
||||
/>
|
||||
<title>New API</title>
|
||||
<analytics></analytics>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"baseUrl": "./",
|
||||
"paths": {
|
||||
"@/*": ["src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["src/**/*"]
|
||||
}
|
||||
662
web/public/oauth-demo.html
Normal file
662
web/public/oauth-demo.html
Normal file
@@ -0,0 +1,662 @@
|
||||
<!-- This file is a copy of examples/oauth-demo.html for direct serving under /oauth-demo.html -->
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>OAuth2/OIDC 授权码 + PKCE 前端演示</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #0b0c10;
|
||||
--panel: #111317;
|
||||
--muted: #aab2bf;
|
||||
--accent: #3b82f6;
|
||||
--ok: #16a34a;
|
||||
--warn: #f59e0b;
|
||||
--err: #ef4444;
|
||||
--border: #1f2430;
|
||||
}
|
||||
body {
|
||||
margin: 0;
|
||||
font-family:
|
||||
ui-sans-serif,
|
||||
system-ui,
|
||||
-apple-system,
|
||||
Segoe UI,
|
||||
Roboto,
|
||||
Helvetica,
|
||||
Arial;
|
||||
background: var(--bg);
|
||||
color: #e5e7eb;
|
||||
}
|
||||
.wrap {
|
||||
max-width: 980px;
|
||||
margin: 32px auto;
|
||||
padding: 0 16px;
|
||||
}
|
||||
h1 {
|
||||
font-size: 22px;
|
||||
margin: 0 0 16px;
|
||||
}
|
||||
.card {
|
||||
background: var(--panel);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 10px;
|
||||
padding: 16px;
|
||||
margin: 12px 0;
|
||||
}
|
||||
.row {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.col {
|
||||
flex: 1 1 280px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
label {
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
input,
|
||||
textarea,
|
||||
select {
|
||||
background: #0f1115;
|
||||
color: #e5e7eb;
|
||||
border: 1px solid var(--border);
|
||||
padding: 10px 12px;
|
||||
border-radius: 8px;
|
||||
outline: none;
|
||||
}
|
||||
textarea {
|
||||
min-height: 100px;
|
||||
resize: vertical;
|
||||
}
|
||||
.btns {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
margin-top: 8px;
|
||||
}
|
||||
button {
|
||||
background: #1a1f2b;
|
||||
color: #e5e7eb;
|
||||
border: 1px solid var(--border);
|
||||
padding: 8px 12px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
}
|
||||
button.primary {
|
||||
background: var(--accent);
|
||||
border-color: var(--accent);
|
||||
color: white;
|
||||
}
|
||||
button.ok {
|
||||
background: var(--ok);
|
||||
border-color: var(--ok);
|
||||
color: white;
|
||||
}
|
||||
button.warn {
|
||||
background: var(--warn);
|
||||
border-color: var(--warn);
|
||||
color: black;
|
||||
}
|
||||
button.ghost {
|
||||
background: transparent;
|
||||
}
|
||||
.muted {
|
||||
color: var(--muted);
|
||||
font-size: 12px;
|
||||
}
|
||||
.mono {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
|
||||
'Liberation Mono', 'Courier New', monospace;
|
||||
}
|
||||
.grid2 {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 12px;
|
||||
}
|
||||
@media (max-width: 880px) {
|
||||
.grid2 {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
.ok {
|
||||
color: #10b981;
|
||||
}
|
||||
.err {
|
||||
color: #ef4444;
|
||||
}
|
||||
.sep {
|
||||
height: 1px;
|
||||
background: var(--border);
|
||||
margin: 12px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<h1>OAuth2/OIDC 授权码 + PKCE 前端演示</h1>
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label
|
||||
>Issuer(可选,用于自动发现
|
||||
/.well-known/openid-configuration)</label
|
||||
>
|
||||
<input id="issuer" placeholder="https://your-domain" />
|
||||
<div class="btns">
|
||||
<button class="" id="btnDiscover">自动发现端点</button>
|
||||
</div>
|
||||
<div class="muted">提示:若未配置 Issuer,可直接填写下方端点。</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Response Type</label>
|
||||
<select id="response_type">
|
||||
<option value="code" selected>code</option>
|
||||
<option value="token">token</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Authorization Endpoint</label
|
||||
><input
|
||||
id="authorization_endpoint"
|
||||
placeholder="https://domain/api/oauth/authorize"
|
||||
/>
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Token Endpoint</label
|
||||
><input
|
||||
id="token_endpoint"
|
||||
placeholder="https://domain/api/oauth/token"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>UserInfo Endpoint(可选)</label
|
||||
><input
|
||||
id="userinfo_endpoint"
|
||||
placeholder="https://domain/api/oauth/userinfo"
|
||||
/>
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Client ID</label
|
||||
><input id="client_id" placeholder="your-public-client-id" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Client Secret(可选,机密客户端)</label
|
||||
><input id="client_secret" placeholder="留空表示公开客户端" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Redirect URI(当前页地址或你的回调)</label
|
||||
><input id="redirect_uri" />
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Scope</label
|
||||
><input id="scope" value="openid profile email" />
|
||||
</div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col"><label>State</label><input id="state" /></div>
|
||||
<div class="col"><label>Nonce</label><input id="nonce" /></div>
|
||||
</div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>Code Verifier(自动生成,不会上送)</label
|
||||
><input id="code_verifier" class="mono" readonly />
|
||||
</div>
|
||||
<div class="col">
|
||||
<label>Code Challenge(S256)</label
|
||||
><input id="code_challenge" class="mono" readonly />
|
||||
</div>
|
||||
</div>
|
||||
<div class="btns">
|
||||
<button id="btnGenPkce">生成 PKCE</button>
|
||||
<button id="btnRandomState">随机 State</button>
|
||||
<button id="btnRandomNonce">随机 Nonce</button>
|
||||
<button id="btnMakeAuthURL">生成授权链接</button>
|
||||
<button id="btnAuthorize" class="primary">跳转授权</button>
|
||||
</div>
|
||||
<div class="row" style="margin-top: 8px">
|
||||
<div class="col">
|
||||
<label>授权链接(只生成不跳转)</label>
|
||||
<textarea
|
||||
id="authorize_url"
|
||||
class="mono"
|
||||
placeholder="(空)"
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnCopyAuthURL">复制链接</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="sep"></div>
|
||||
<div class="muted">
|
||||
说明:
|
||||
<ul>
|
||||
<li>
|
||||
本页为纯前端演示,适用于公开客户端(不需要 client_secret)。
|
||||
</li>
|
||||
<li>
|
||||
如跨域调用 Token/UserInfo,需要服务端正确设置 CORS;建议将此 demo
|
||||
部署到同源域名下。
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="sep"></div>
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label
|
||||
>粘贴 OIDC Discovery
|
||||
JSON(/.well-known/openid-configuration)</label
|
||||
>
|
||||
<textarea
|
||||
id="conf_json"
|
||||
class="mono"
|
||||
placeholder='{"issuer":"https://...","authorization_endpoint":"...","token_endpoint":"...","userinfo_endpoint":"..."}'
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnParseConf">解析并填充端点</button>
|
||||
<button id="btnGenConf">用当前端点生成 JSON</button>
|
||||
</div>
|
||||
<div class="muted">
|
||||
可将服务端返回的 OIDC Discovery JSON
|
||||
粘贴到此处,点击“解析并填充端点”。
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="row">
|
||||
<div class="col">
|
||||
<label>授权结果</label>
|
||||
<div id="authResult" class="muted">等待授权...</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top: 12px">
|
||||
<div>
|
||||
<label>Access Token</label>
|
||||
<textarea
|
||||
id="access_token"
|
||||
class="mono"
|
||||
placeholder="(空)"
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnCopyAT">复制</button
|
||||
><button id="btnCallUserInfo" class="ok">调用 UserInfo</button>
|
||||
</div>
|
||||
<div id="userinfoOut" class="muted" style="margin-top: 6px"></div>
|
||||
</div>
|
||||
<div>
|
||||
<label>ID Token(JWT)</label>
|
||||
<textarea id="id_token" class="mono" placeholder="(空)"></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnDecodeJWT">解码显示 Claims</button>
|
||||
</div>
|
||||
<pre
|
||||
id="jwtClaims"
|
||||
class="mono"
|
||||
style="
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
margin-top: 6px;
|
||||
"
|
||||
></pre>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid2" style="margin-top: 12px">
|
||||
<div>
|
||||
<label>Refresh Token</label>
|
||||
<textarea
|
||||
id="refresh_token"
|
||||
class="mono"
|
||||
placeholder="(空)"
|
||||
></textarea>
|
||||
<div class="btns">
|
||||
<button id="btnRefreshToken">使用 Refresh Token 刷新</button>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<label>原始 Token 响应</label>
|
||||
<textarea id="token_raw" class="mono" placeholder="(空)"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<script>
|
||||
const $ = (id) => document.getElementById(id);
|
||||
const toB64Url = (buf) =>
|
||||
btoa(String.fromCharCode(...new Uint8Array(buf)))
|
||||
.replace(/\+/g, '-')
|
||||
.replace(/\//g, '_')
|
||||
.replace(/=+$/, '');
|
||||
async function sha256B64Url(str) {
|
||||
const data = new TextEncoder().encode(str);
|
||||
const digest = await crypto.subtle.digest('SHA-256', data);
|
||||
return toB64Url(digest);
|
||||
}
|
||||
function randStr(len = 64) {
|
||||
const chars =
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';
|
||||
const arr = new Uint8Array(len);
|
||||
crypto.getRandomValues(arr);
|
||||
return Array.from(arr, (v) => chars[v % chars.length]).join('');
|
||||
}
|
||||
function setAuthInfo(msg, ok = true) {
|
||||
const el = $('authResult');
|
||||
el.textContent = msg;
|
||||
el.className = ok ? 'ok' : 'err';
|
||||
}
|
||||
function qs(name) {
|
||||
const u = new URL(location.href);
|
||||
return u.searchParams.get(name);
|
||||
}
|
||||
function persist(k, v) {
|
||||
sessionStorage.setItem('demo_' + k, v);
|
||||
}
|
||||
function load(k) {
|
||||
return sessionStorage.getItem('demo_' + k) || '';
|
||||
}
|
||||
(function init() {
|
||||
$('redirect_uri').value =
|
||||
window.location.origin + window.location.pathname;
|
||||
const iss = load('issuer');
|
||||
if (iss) $('issuer').value = iss;
|
||||
const cid = load('client_id');
|
||||
if (cid) $('client_id').value = cid;
|
||||
const scp = load('scope');
|
||||
if (scp) $('scope').value = scp;
|
||||
})();
|
||||
$('btnDiscover').onclick = async () => {
|
||||
const iss = $('issuer').value.trim();
|
||||
if (!iss) {
|
||||
alert('请填写 Issuer');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
persist('issuer', iss);
|
||||
const res = await fetch(
|
||||
iss.replace(/\/$/, '') + '/api/.well-known/openid-configuration',
|
||||
);
|
||||
const d = await res.json();
|
||||
$('authorization_endpoint').value = d.authorization_endpoint || '';
|
||||
$('token_endpoint').value = d.token_endpoint || '';
|
||||
$('userinfo_endpoint').value = d.userinfo_endpoint || '';
|
||||
if (d.issuer) {
|
||||
$('issuer').value = d.issuer;
|
||||
persist('issuer', d.issuer);
|
||||
}
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
setAuthInfo('已从发现文档加载端点', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('自动发现失败:' + e, false);
|
||||
}
|
||||
};
|
||||
$('btnGenPkce').onclick = async () => {
|
||||
const v = randStr(64);
|
||||
const c = await sha256B64Url(v);
|
||||
$('code_verifier').value = v;
|
||||
$('code_challenge').value = c;
|
||||
persist('code_verifier', v);
|
||||
persist('code_challenge', c);
|
||||
setAuthInfo('已生成 PKCE 参数', true);
|
||||
};
|
||||
$('btnRandomState').onclick = () => {
|
||||
$('state').value = randStr(16);
|
||||
persist('state', $('state').value);
|
||||
};
|
||||
$('btnRandomNonce').onclick = () => {
|
||||
$('nonce').value = randStr(16);
|
||||
persist('nonce', $('nonce').value);
|
||||
};
|
||||
function buildAuthorizeURLFromFields() {
|
||||
const auth = $('authorization_endpoint').value.trim();
|
||||
const token = $('token_endpoint').value.trim();
|
||||
const cid = $('client_id').value.trim();
|
||||
const red = $('redirect_uri').value.trim();
|
||||
const scp = $('scope').value.trim() || 'openid profile email';
|
||||
const rt = $('response_type').value;
|
||||
const st = $('state').value.trim() || randStr(16);
|
||||
const no = $('nonce').value.trim() || randStr(16);
|
||||
const cc = $('code_challenge').value.trim();
|
||||
const cv = $('code_verifier').value.trim();
|
||||
if (!auth || !cid || !red) {
|
||||
throw new Error('请先完善端点/ClientID/RedirectURI');
|
||||
}
|
||||
if (rt === 'code' && (!cc || !cv)) {
|
||||
throw new Error('请先生成 PKCE');
|
||||
}
|
||||
persist('authorization_endpoint', auth);
|
||||
persist('token_endpoint', token);
|
||||
persist('client_id', cid);
|
||||
persist('redirect_uri', red);
|
||||
persist('scope', scp);
|
||||
persist('state', st);
|
||||
persist('nonce', no);
|
||||
persist('code_verifier', cv);
|
||||
const u = new URL(auth);
|
||||
u.searchParams.set('response_type', rt);
|
||||
u.searchParams.set('client_id', cid);
|
||||
u.searchParams.set('redirect_uri', red);
|
||||
u.searchParams.set('scope', scp);
|
||||
u.searchParams.set('state', st);
|
||||
if (no) u.searchParams.set('nonce', no);
|
||||
if (rt === 'code') {
|
||||
u.searchParams.set('code_challenge', cc);
|
||||
u.searchParams.set('code_challenge_method', 'S256');
|
||||
}
|
||||
return u.toString();
|
||||
}
|
||||
$('btnMakeAuthURL').onclick = () => {
|
||||
try {
|
||||
const url = buildAuthorizeURLFromFields();
|
||||
$('authorize_url').value = url;
|
||||
setAuthInfo('已生成授权链接', true);
|
||||
} catch (e) {
|
||||
setAuthInfo(e.message, false);
|
||||
}
|
||||
};
|
||||
$('btnAuthorize').onclick = () => {
|
||||
try {
|
||||
const url = buildAuthorizeURLFromFields();
|
||||
location.href = url;
|
||||
} catch (e) {
|
||||
setAuthInfo(e.message, false);
|
||||
}
|
||||
};
|
||||
$('btnCopyAuthURL').onclick = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText($('authorize_url').value);
|
||||
} catch {}
|
||||
};
|
||||
async function postForm(url, data, basic) {
|
||||
const body = Object.entries(data)
|
||||
.map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(v)}`)
|
||||
.join('&');
|
||||
const headers = { 'Content-Type': 'application/x-www-form-urlencoded' };
|
||||
if (basic && basic.id && basic.secret) {
|
||||
headers['Authorization'] =
|
||||
'Basic ' + btoa(`${basic.id}:${basic.secret}`);
|
||||
}
|
||||
const res = await fetch(url, { method: 'POST', headers, body });
|
||||
if (!res.ok) {
|
||||
const t = await res.text();
|
||||
throw new Error(`HTTP ${res.status} ${t}`);
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
async function handleCallback() {
|
||||
const frag =
|
||||
location.hash && location.hash.startsWith('#')
|
||||
? new URLSearchParams(location.hash.slice(1))
|
||||
: null;
|
||||
const at = frag ? frag.get('access_token') : null;
|
||||
const err = qs('error') || (frag ? frag.get('error') : null);
|
||||
const state = qs('state') || (frag ? frag.get('state') : null);
|
||||
if (err) {
|
||||
setAuthInfo('授权失败:' + err, false);
|
||||
return;
|
||||
}
|
||||
if (at) {
|
||||
$('access_token').value = at || '';
|
||||
$('token_raw').value = JSON.stringify(
|
||||
{
|
||||
access_token: at,
|
||||
token_type: frag.get('token_type'),
|
||||
expires_in: frag.get('expires_in'),
|
||||
scope: frag.get('scope'),
|
||||
state,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
);
|
||||
setAuthInfo('隐式模式已获取 Access Token', true);
|
||||
return;
|
||||
}
|
||||
const code = qs('code');
|
||||
if (!code) {
|
||||
setAuthInfo('等待授权...', true);
|
||||
return;
|
||||
}
|
||||
if (state && load('state') && state !== load('state')) {
|
||||
setAuthInfo('state 不匹配,已拒绝', false);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const tokenEp = load('token_endpoint');
|
||||
const cid = load('client_id');
|
||||
const csec = $('client_secret').value.trim();
|
||||
const basic = csec ? { id: cid, secret: csec } : null;
|
||||
const data = await postForm(
|
||||
tokenEp,
|
||||
{
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
client_id: cid,
|
||||
redirect_uri: load('redirect_uri'),
|
||||
code_verifier: load('code_verifier'),
|
||||
},
|
||||
basic,
|
||||
);
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('授权成功,已获取令牌', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('交换令牌失败:' + e.message, false);
|
||||
}
|
||||
}
|
||||
handleCallback();
|
||||
$('btnCopyAT').onclick = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText($('access_token').value);
|
||||
} catch {}
|
||||
};
|
||||
$('btnDecodeJWT').onclick = () => {
|
||||
const t = $('id_token').value.trim();
|
||||
if (!t) {
|
||||
$('jwtClaims').textContent = '(空)';
|
||||
return;
|
||||
}
|
||||
const parts = t.split('.');
|
||||
if (parts.length < 2) {
|
||||
$('jwtClaims').textContent = '格式错误';
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const json = JSON.parse(
|
||||
atob(parts[1].replace(/-/g, '+').replace(/_/g, '/')),
|
||||
);
|
||||
$('jwtClaims').textContent = JSON.stringify(json, null, 2);
|
||||
} catch (e) {
|
||||
$('jwtClaims').textContent = '解码失败:' + e;
|
||||
}
|
||||
};
|
||||
$('btnCallUserInfo').onclick = async () => {
|
||||
const at = $('access_token').value.trim();
|
||||
const ep = $('userinfo_endpoint').value.trim();
|
||||
if (!at || !ep) {
|
||||
alert('请填写UserInfo端点并获取AccessToken');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const res = await fetch(ep, {
|
||||
headers: { Authorization: 'Bearer ' + at },
|
||||
});
|
||||
const data = await res.json();
|
||||
$('userinfoOut').textContent = JSON.stringify(data, null, 2);
|
||||
} catch (e) {
|
||||
$('userinfoOut').textContent = '调用失败:' + e;
|
||||
}
|
||||
};
|
||||
$('btnRefreshToken').onclick = async () => {
|
||||
const rt = $('refresh_token').value.trim();
|
||||
if (!rt) {
|
||||
alert('没有刷新令牌');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const tokenEp = load('token_endpoint');
|
||||
const cid = load('client_id');
|
||||
const csec = $('client_secret').value.trim();
|
||||
const basic = csec ? { id: cid, secret: csec } : null;
|
||||
const data = await postForm(
|
||||
tokenEp,
|
||||
{ grant_type: 'refresh_token', refresh_token: rt, client_id: cid },
|
||||
basic,
|
||||
);
|
||||
$('access_token').value = data.access_token || '';
|
||||
$('id_token').value = data.id_token || '';
|
||||
$('refresh_token').value = data.refresh_token || '';
|
||||
$('token_raw').value = JSON.stringify(data, null, 2);
|
||||
setAuthInfo('刷新成功', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('刷新失败:' + e.message, false);
|
||||
}
|
||||
};
|
||||
$('btnParseConf').onclick = () => {
|
||||
const txt = $('conf_json').value.trim();
|
||||
if (!txt) {
|
||||
alert('请先粘贴 JSON');
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const d = JSON.parse(txt);
|
||||
if (d.issuer) {
|
||||
$('issuer').value = d.issuer;
|
||||
persist('issuer', d.issuer);
|
||||
}
|
||||
if (d.authorization_endpoint)
|
||||
$('authorization_endpoint').value = d.authorization_endpoint;
|
||||
if (d.token_endpoint) $('token_endpoint').value = d.token_endpoint;
|
||||
if (d.userinfo_endpoint)
|
||||
$('userinfo_endpoint').value = d.userinfo_endpoint;
|
||||
setAuthInfo('已解析配置并填充端点', true);
|
||||
} catch (e) {
|
||||
setAuthInfo('解析失败:' + e, false);
|
||||
}
|
||||
};
|
||||
$('btnGenConf').onclick = () => {
|
||||
const d = {
|
||||
issuer: $('issuer').value.trim() || undefined,
|
||||
authorization_endpoint:
|
||||
$('authorization_endpoint').value.trim() || undefined,
|
||||
token_endpoint: $('token_endpoint').value.trim() || undefined,
|
||||
userinfo_endpoint: $('userinfo_endpoint').value.trim() || undefined,
|
||||
};
|
||||
$('conf_json').value = JSON.stringify(d, null, 2);
|
||||
};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -44,6 +44,7 @@ import Task from './pages/Task';
|
||||
import ModelPage from './pages/Model';
|
||||
import Playground from './pages/Playground';
|
||||
import OAuth2Callback from './components/auth/OAuth2Callback';
|
||||
import OAuthConsent from './pages/OAuth';
|
||||
import PersonalSetting from './components/settings/PersonalSetting';
|
||||
import Setup from './pages/Setup';
|
||||
import SetupCheck from './components/layout/SetupCheck';
|
||||
@@ -198,6 +199,14 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/consent'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>}>
|
||||
<OAuthConsent />
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/linuxdo'
|
||||
element={
|
||||
|
||||
@@ -176,7 +176,11 @@ const LoginForm = () => {
|
||||
centered: true,
|
||||
});
|
||||
}
|
||||
navigate('/console');
|
||||
// 优先跳回 next(仅允许相对路径)
|
||||
const sp = new URLSearchParams(window.location.search);
|
||||
const next = sp.get('next');
|
||||
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
|
||||
navigate(isSafeInternalPath ? next : '/console');
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -286,7 +290,10 @@ const LoginForm = () => {
|
||||
setUserData(data);
|
||||
updateAPI();
|
||||
showSuccess('登录成功!');
|
||||
navigate('/console');
|
||||
const sp = new URLSearchParams(window.location.search);
|
||||
const next = sp.get('next');
|
||||
const isSafeInternalPath = next && next.startsWith('/') && !next.startsWith('//');
|
||||
navigate(isSafeInternalPath ? next : '/console');
|
||||
};
|
||||
|
||||
// 返回登录页面
|
||||
|
||||
@@ -181,8 +181,8 @@ export function PreCode(props) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (ref.current) {
|
||||
const codeElement = ref.current.querySelector('code');
|
||||
const code = codeElement?.textContent ?? '';
|
||||
const code =
|
||||
ref.current.querySelector('code')?.innerText ?? '';
|
||||
copy(code).then((success) => {
|
||||
if (success) {
|
||||
Toast.success(t('代码已复制到剪贴板'));
|
||||
|
||||
@@ -135,7 +135,9 @@ const TwoFactorAuthModal = ({
|
||||
autoFocus
|
||||
/>
|
||||
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
|
||||
{t('支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。')}
|
||||
{t(
|
||||
'支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。',
|
||||
)}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -21,7 +21,7 @@ import React, { useState, useMemo, useCallback } from 'react';
|
||||
import { Button, Tooltip, Toast } from '@douyinfe/semi-ui';
|
||||
import { Copy, ChevronDown, ChevronUp } from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { copy } from '../../helpers';
|
||||
import { copy } from '../../../helpers';
|
||||
|
||||
const PERFORMANCE_CONFIG = {
|
||||
MAX_DISPLAY_LENGTH: 50000, // 最大显示字符数
|
||||
135
web/src/components/common/ui/ResponsiveModal.jsx
Normal file
135
web/src/components/common/ui/ResponsiveModal.jsx
Normal file
@@ -0,0 +1,135 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal, Typography } from '@douyinfe/semi-ui';
|
||||
import PropTypes from 'prop-types';
|
||||
import { useIsMobile } from '../../../hooks/common/useIsMobile';
|
||||
|
||||
const { Title } = Typography;
|
||||
|
||||
/**
|
||||
* ResponsiveModal 响应式模态框组件
|
||||
*
|
||||
* 特性:
|
||||
* - 响应式布局:移动端和桌面端不同的宽度和布局
|
||||
* - 自定义头部:标题左对齐,操作按钮右对齐,移动端自动换行
|
||||
* - Tailwind CSS 样式支持
|
||||
* - 保持原 Modal 组件的所有功能
|
||||
*/
|
||||
const ResponsiveModal = ({
|
||||
visible,
|
||||
onCancel,
|
||||
title,
|
||||
headerActions = [],
|
||||
children,
|
||||
width = { mobile: '95%', desktop: 600 },
|
||||
className = '',
|
||||
footer = null,
|
||||
titleProps = {},
|
||||
headerClassName = '',
|
||||
actionsClassName = '',
|
||||
...props
|
||||
}) => {
|
||||
const isMobile = useIsMobile();
|
||||
|
||||
// 自定义 Header 组件
|
||||
const CustomHeader = () => {
|
||||
if (!title && (!headerActions || headerActions.length === 0)) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`flex w-full gap-3 justify-between ${
|
||||
isMobile ? 'flex-col items-start' : 'flex-row items-center'
|
||||
} ${headerClassName}`}
|
||||
>
|
||||
{title && (
|
||||
<Title heading={5} className='m-0 min-w-fit' {...titleProps}>
|
||||
{title}
|
||||
</Title>
|
||||
)}
|
||||
{headerActions && headerActions.length > 0 && (
|
||||
<div
|
||||
className={`flex flex-wrap gap-2 items-center ${
|
||||
isMobile ? 'w-full justify-start' : 'w-auto justify-end'
|
||||
} ${actionsClassName}`}
|
||||
>
|
||||
{headerActions.map((action, index) => (
|
||||
<React.Fragment key={index}>{action}</React.Fragment>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 计算模态框宽度
|
||||
const getModalWidth = () => {
|
||||
if (typeof width === 'object') {
|
||||
return isMobile ? width.mobile : width.desktop;
|
||||
}
|
||||
return width;
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
visible={visible}
|
||||
title={<CustomHeader />}
|
||||
onCancel={onCancel}
|
||||
footer={footer}
|
||||
width={getModalWidth()}
|
||||
className={`!top-12 ${className}`}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
ResponsiveModal.propTypes = {
|
||||
// Modal 基础属性
|
||||
visible: PropTypes.bool.isRequired,
|
||||
onCancel: PropTypes.func.isRequired,
|
||||
children: PropTypes.node,
|
||||
|
||||
// 自定义头部
|
||||
title: PropTypes.oneOfType([PropTypes.string, PropTypes.node]),
|
||||
headerActions: PropTypes.arrayOf(PropTypes.node),
|
||||
|
||||
// 样式和布局
|
||||
width: PropTypes.oneOfType([
|
||||
PropTypes.number,
|
||||
PropTypes.string,
|
||||
PropTypes.shape({
|
||||
mobile: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
|
||||
desktop: PropTypes.oneOfType([PropTypes.number, PropTypes.string]),
|
||||
}),
|
||||
]),
|
||||
className: PropTypes.string,
|
||||
footer: PropTypes.node,
|
||||
|
||||
// 标题自定义属性
|
||||
titleProps: PropTypes.object,
|
||||
|
||||
// 自定义 CSS 类
|
||||
headerClassName: PropTypes.string,
|
||||
actionsClassName: PropTypes.string,
|
||||
};
|
||||
|
||||
export default ResponsiveModal;
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useRef } from 'react';
|
||||
import React from 'react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import { Avatar, Button, Dropdown, Typography } from '@douyinfe/semi-ui';
|
||||
import { ChevronDown } from 'lucide-react';
|
||||
@@ -39,7 +39,6 @@ const UserArea = ({
|
||||
navigate,
|
||||
t,
|
||||
}) => {
|
||||
const dropdownRef = useRef(null);
|
||||
if (isLoading) {
|
||||
return (
|
||||
<SkeletonWrapper
|
||||
@@ -53,93 +52,90 @@ const UserArea = ({
|
||||
|
||||
if (userState.user) {
|
||||
return (
|
||||
<div className='relative' ref={dropdownRef}>
|
||||
<Dropdown
|
||||
position='bottomRight'
|
||||
getPopupContainer={() => dropdownRef.current}
|
||||
render={
|
||||
<Dropdown.Menu className='!bg-semi-color-bg-overlay !border-semi-color-border !shadow-lg !rounded-lg dark:!bg-gray-700 dark:!border-gray-600'>
|
||||
<Dropdown.Item
|
||||
onClick={() => {
|
||||
navigate('/console/personal');
|
||||
}}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconUserSetting
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('个人设置')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => {
|
||||
navigate('/console/token');
|
||||
}}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconKey
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('令牌管理')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => {
|
||||
navigate('/console/topup');
|
||||
}}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconCreditCard
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('钱包管理')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={logout}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-red-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconExit
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('退出')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Button
|
||||
theme='borderless'
|
||||
type='tertiary'
|
||||
className='flex items-center gap-1.5 !p-1 !rounded-full hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-700 !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2'
|
||||
>
|
||||
<Avatar
|
||||
size='extra-small'
|
||||
color={stringToColor(userState.user.username)}
|
||||
className='mr-1'
|
||||
<Dropdown
|
||||
position='bottomRight'
|
||||
render={
|
||||
<Dropdown.Menu className='!bg-semi-color-bg-overlay !border-semi-color-border !shadow-lg !rounded-lg dark:!bg-gray-700 dark:!border-gray-600'>
|
||||
<Dropdown.Item
|
||||
onClick={() => {
|
||||
navigate('/console/personal');
|
||||
}}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
|
||||
>
|
||||
{userState.user.username[0].toUpperCase()}
|
||||
</Avatar>
|
||||
<span className='hidden md:inline'>
|
||||
<Typography.Text className='!text-xs !font-medium !text-semi-color-text-1 dark:!text-gray-300 mr-1'>
|
||||
{userState.user.username}
|
||||
</Typography.Text>
|
||||
</span>
|
||||
<ChevronDown
|
||||
size={14}
|
||||
className='text-xs text-semi-color-text-2 dark:text-gray-400'
|
||||
/>
|
||||
</Button>
|
||||
</Dropdown>
|
||||
</div>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconUserSetting
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('个人设置')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => {
|
||||
navigate('/console/token');
|
||||
}}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconKey
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('令牌管理')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => {
|
||||
navigate('/console/topup');
|
||||
}}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-blue-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconCreditCard
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('钱包管理')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={logout}
|
||||
className='!px-3 !py-1.5 !text-sm !text-semi-color-text-0 hover:!bg-semi-color-fill-1 dark:!text-gray-200 dark:hover:!bg-red-500 dark:hover:!text-white'
|
||||
>
|
||||
<div className='flex items-center gap-2'>
|
||||
<IconExit
|
||||
size='small'
|
||||
className='text-gray-500 dark:text-gray-400'
|
||||
/>
|
||||
<span>{t('退出')}</span>
|
||||
</div>
|
||||
</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Button
|
||||
theme='borderless'
|
||||
type='tertiary'
|
||||
className='flex items-center gap-1.5 !p-1 !rounded-full hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-700 !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2'
|
||||
>
|
||||
<Avatar
|
||||
size='extra-small'
|
||||
color={stringToColor(userState.user.username)}
|
||||
className='mr-1'
|
||||
>
|
||||
{userState.user.username[0].toUpperCase()}
|
||||
</Avatar>
|
||||
<span className='hidden md:inline'>
|
||||
<Typography.Text className='!text-xs !font-medium !text-semi-color-text-1 dark:!text-gray-300 mr-1'>
|
||||
{userState.user.username}
|
||||
</Typography.Text>
|
||||
</span>
|
||||
<ChevronDown
|
||||
size={14}
|
||||
className='text-xs text-semi-color-text-2 dark:text-gray-400'
|
||||
/>
|
||||
</Button>
|
||||
</Dropdown>
|
||||
);
|
||||
} else {
|
||||
const showRegisterButton = !isSelfUseMode;
|
||||
|
||||
@@ -28,7 +28,7 @@ import {
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { Code, Zap, Clock, X, Eye, Send } from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CodeViewer from './CodeViewer';
|
||||
import CodeViewer from '../common/ui/CodeViewer';
|
||||
|
||||
const DebugPanel = ({
|
||||
debugData,
|
||||
|
||||
72
web/src/components/settings/OAuth2Setting.jsx
Normal file
72
web/src/components/settings/OAuth2Setting.jsx
Normal file
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Spin } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import OAuth2ServerSettings from './oauth2/OAuth2ServerSettings';
|
||||
import OAuth2ClientSettings from './oauth2/OAuth2ClientSettings';
|
||||
|
||||
const OAuth2Setting = () => {
|
||||
const { t } = useTranslation();
|
||||
const [options, setOptions] = useState({});
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const getOptions = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
const map = {};
|
||||
for (const item of data) {
|
||||
map[item.key] = item.value;
|
||||
}
|
||||
setOptions(map);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('获取OAuth2设置失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const refresh = () => {
|
||||
getOptions();
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 服务器配置 */}
|
||||
<OAuth2ServerSettings options={options} refresh={refresh} />
|
||||
|
||||
{/* 客户端管理 */}
|
||||
<OAuth2ClientSettings />
|
||||
</Spin>
|
||||
);
|
||||
};
|
||||
|
||||
export default OAuth2Setting;
|
||||
@@ -45,7 +45,6 @@ const PaymentSetting = () => {
|
||||
StripePriceId: '',
|
||||
StripeUnitPrice: 8.0,
|
||||
StripeMinTopUp: 1,
|
||||
StripePromotionCodesEnabled: false,
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
@@ -19,14 +19,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import {
|
||||
API,
|
||||
copy,
|
||||
showError,
|
||||
showInfo,
|
||||
showSuccess,
|
||||
setStatusData,
|
||||
} from '../../helpers';
|
||||
import { API, copy, showError, showInfo, showSuccess } from '../../helpers';
|
||||
import { UserContext } from '../../context/User';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -78,40 +71,18 @@ const PersonalSetting = () => {
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
let saved = localStorage.getItem('status');
|
||||
if (saved) {
|
||||
const parsed = JSON.parse(saved);
|
||||
setStatus(parsed);
|
||||
if (parsed.turnstile_check) {
|
||||
let status = localStorage.getItem('status');
|
||||
if (status) {
|
||||
status = JSON.parse(status);
|
||||
setStatus(status);
|
||||
if (status.turnstile_check) {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(parsed.turnstile_site_key);
|
||||
} else {
|
||||
setTurnstileEnabled(false);
|
||||
setTurnstileSiteKey('');
|
||||
setTurnstileSiteKey(status.turnstile_site_key);
|
||||
}
|
||||
}
|
||||
// Always refresh status from server to avoid stale flags (e.g., admin just enabled OAuth)
|
||||
(async () => {
|
||||
try {
|
||||
const res = await API.get('/api/status');
|
||||
const { success, data } = res.data;
|
||||
if (success && data) {
|
||||
setStatus(data);
|
||||
setStatusData(data);
|
||||
if (data.turnstile_check) {
|
||||
setTurnstileEnabled(true);
|
||||
setTurnstileSiteKey(data.turnstile_site_key);
|
||||
} else {
|
||||
setTurnstileEnabled(false);
|
||||
setTurnstileSiteKey('');
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// ignore and keep local status
|
||||
}
|
||||
})();
|
||||
|
||||
getUserData();
|
||||
getUserData().then((res) => {
|
||||
console.log(userState);
|
||||
});
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -39,9 +39,6 @@ const RatioSetting = () => {
|
||||
CompletionRatio: '',
|
||||
GroupRatio: '',
|
||||
GroupGroupRatio: '',
|
||||
ImageRatio: '',
|
||||
AudioRatio: '',
|
||||
AudioCompletionRatio: '',
|
||||
AutoGroups: '',
|
||||
DefaultUseAutoGroup: false,
|
||||
ExposeRatioEnabled: false,
|
||||
@@ -64,10 +61,7 @@ const RatioSetting = () => {
|
||||
item.key === 'UserUsableGroups' ||
|
||||
item.key === 'CompletionRatio' ||
|
||||
item.key === 'ModelPrice' ||
|
||||
item.key === 'CacheRatio' ||
|
||||
item.key === 'ImageRatio' ||
|
||||
item.key === 'AudioRatio' ||
|
||||
item.key === 'AudioCompletionRatio'
|
||||
item.key === 'CacheRatio'
|
||||
) {
|
||||
try {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
|
||||
@@ -29,7 +29,6 @@ import {
|
||||
TagInput,
|
||||
Spin,
|
||||
Card,
|
||||
Radio,
|
||||
} from '@douyinfe/semi-ui';
|
||||
const { Text } = Typography;
|
||||
import {
|
||||
@@ -45,7 +44,6 @@ import { useTranslation } from 'react-i18next';
|
||||
const SystemSetting = () => {
|
||||
const { t } = useTranslation();
|
||||
let [inputs, setInputs] = useState({
|
||||
|
||||
PasswordLoginEnabled: '',
|
||||
PasswordRegisterEnabled: '',
|
||||
EmailVerificationEnabled: '',
|
||||
@@ -89,15 +87,6 @@ const SystemSetting = () => {
|
||||
LinuxDOClientSecret: '',
|
||||
LinuxDOMinimumTrustLevel: '',
|
||||
ServerAddress: '',
|
||||
// SSRF防护配置
|
||||
'fetch_setting.enable_ssrf_protection': true,
|
||||
'fetch_setting.allow_private_ip': '',
|
||||
'fetch_setting.domain_filter_mode': false, // true 白名单,false 黑名单
|
||||
'fetch_setting.ip_filter_mode': false, // true 白名单,false 黑名单
|
||||
'fetch_setting.domain_list': [],
|
||||
'fetch_setting.ip_list': [],
|
||||
'fetch_setting.allowed_ports': [],
|
||||
'fetch_setting.apply_ip_filter_for_domain': false,
|
||||
});
|
||||
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
@@ -109,11 +98,6 @@ const SystemSetting = () => {
|
||||
useState(false);
|
||||
const [linuxDOOAuthEnabled, setLinuxDOOAuthEnabled] = useState(false);
|
||||
const [emailToAdd, setEmailToAdd] = useState('');
|
||||
const [domainFilterMode, setDomainFilterMode] = useState(true);
|
||||
const [ipFilterMode, setIpFilterMode] = useState(true);
|
||||
const [domainList, setDomainList] = useState([]);
|
||||
const [ipList, setIpList] = useState([]);
|
||||
const [allowedPorts, setAllowedPorts] = useState([]);
|
||||
|
||||
const getOptions = async () => {
|
||||
setLoading(true);
|
||||
@@ -129,37 +113,6 @@ const SystemSetting = () => {
|
||||
case 'EmailDomainWhitelist':
|
||||
setEmailDomainWhitelist(item.value ? item.value.split(',') : []);
|
||||
break;
|
||||
case 'fetch_setting.allow_private_ip':
|
||||
case 'fetch_setting.enable_ssrf_protection':
|
||||
case 'fetch_setting.domain_filter_mode':
|
||||
case 'fetch_setting.ip_filter_mode':
|
||||
case 'fetch_setting.apply_ip_filter_for_domain':
|
||||
item.value = toBoolean(item.value);
|
||||
break;
|
||||
case 'fetch_setting.domain_list':
|
||||
try {
|
||||
const domains = item.value ? JSON.parse(item.value) : [];
|
||||
setDomainList(Array.isArray(domains) ? domains : []);
|
||||
} catch (e) {
|
||||
setDomainList([]);
|
||||
}
|
||||
break;
|
||||
case 'fetch_setting.ip_list':
|
||||
try {
|
||||
const ips = item.value ? JSON.parse(item.value) : [];
|
||||
setIpList(Array.isArray(ips) ? ips : []);
|
||||
} catch (e) {
|
||||
setIpList([]);
|
||||
}
|
||||
break;
|
||||
case 'fetch_setting.allowed_ports':
|
||||
try {
|
||||
const ports = item.value ? JSON.parse(item.value) : [];
|
||||
setAllowedPorts(Array.isArray(ports) ? ports : []);
|
||||
} catch (e) {
|
||||
setAllowedPorts(['80', '443', '8080', '8443']);
|
||||
}
|
||||
break;
|
||||
case 'PasswordLoginEnabled':
|
||||
case 'PasswordRegisterEnabled':
|
||||
case 'EmailVerificationEnabled':
|
||||
@@ -187,13 +140,6 @@ const SystemSetting = () => {
|
||||
});
|
||||
setInputs(newInputs);
|
||||
setOriginInputs(newInputs);
|
||||
// 同步模式布尔到本地状态
|
||||
if (typeof newInputs['fetch_setting.domain_filter_mode'] !== 'undefined') {
|
||||
setDomainFilterMode(!!newInputs['fetch_setting.domain_filter_mode']);
|
||||
}
|
||||
if (typeof newInputs['fetch_setting.ip_filter_mode'] !== 'undefined') {
|
||||
setIpFilterMode(!!newInputs['fetch_setting.ip_filter_mode']);
|
||||
}
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(newInputs);
|
||||
}
|
||||
@@ -330,46 +276,6 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitSSRF = async () => {
|
||||
const options = [];
|
||||
|
||||
// 处理域名过滤模式与列表
|
||||
options.push({
|
||||
key: 'fetch_setting.domain_filter_mode',
|
||||
value: domainFilterMode,
|
||||
});
|
||||
if (Array.isArray(domainList)) {
|
||||
options.push({
|
||||
key: 'fetch_setting.domain_list',
|
||||
value: JSON.stringify(domainList),
|
||||
});
|
||||
}
|
||||
|
||||
// 处理IP过滤模式与列表
|
||||
options.push({
|
||||
key: 'fetch_setting.ip_filter_mode',
|
||||
value: ipFilterMode,
|
||||
});
|
||||
if (Array.isArray(ipList)) {
|
||||
options.push({
|
||||
key: 'fetch_setting.ip_list',
|
||||
value: JSON.stringify(ipList),
|
||||
});
|
||||
}
|
||||
|
||||
// 处理端口配置
|
||||
if (Array.isArray(allowedPorts)) {
|
||||
options.push({
|
||||
key: 'fetch_setting.allowed_ports',
|
||||
value: JSON.stringify(allowedPorts),
|
||||
});
|
||||
}
|
||||
|
||||
if (options.length > 0) {
|
||||
await updateOptions(options);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAddEmail = () => {
|
||||
if (emailToAdd && emailToAdd.trim() !== '') {
|
||||
const domain = emailToAdd.trim();
|
||||
@@ -681,179 +587,6 @@ const SystemSetting = () => {
|
||||
</Form.Section>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<Form.Section text={t('SSRF防护设置')}>
|
||||
<Text extraText={t('SSRF防护详细说明')}>
|
||||
{t('配置服务器端请求伪造(SSRF)防护,用于保护内网资源安全')}
|
||||
</Text>
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
|
||||
<Form.Checkbox
|
||||
field='fetch_setting.enable_ssrf_protection'
|
||||
noLabel
|
||||
extraText={t('SSRF防护开关详细说明')}
|
||||
onChange={(e) =>
|
||||
handleCheckboxChange('fetch_setting.enable_ssrf_protection', e)
|
||||
}
|
||||
>
|
||||
{t('启用SSRF防护(推荐开启以保护服务器安全)')}
|
||||
</Form.Checkbox>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
style={{ marginTop: 16 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
|
||||
<Form.Checkbox
|
||||
field='fetch_setting.allow_private_ip'
|
||||
noLabel
|
||||
extraText={t('私有IP访问详细说明')}
|
||||
onChange={(e) =>
|
||||
handleCheckboxChange('fetch_setting.allow_private_ip', e)
|
||||
}
|
||||
>
|
||||
{t('允许访问私有IP地址(127.0.0.1、192.168.x.x等内网地址)')}
|
||||
</Form.Checkbox>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
style={{ marginTop: 16 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
|
||||
<Form.Checkbox
|
||||
field='fetch_setting.apply_ip_filter_for_domain'
|
||||
noLabel
|
||||
extraText={t('域名IP过滤详细说明')}
|
||||
onChange={(e) =>
|
||||
handleCheckboxChange('fetch_setting.apply_ip_filter_for_domain', e)
|
||||
}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
{t('对域名启用 IP 过滤(实验性)')}
|
||||
</Form.Checkbox>
|
||||
<Text strong>
|
||||
{t(domainFilterMode ? '域名白名单' : '域名黑名单')}
|
||||
</Text>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('支持通配符格式,如:example.com, *.api.example.com')}
|
||||
</Text>
|
||||
<Radio.Group
|
||||
type='button'
|
||||
value={domainFilterMode ? 'whitelist' : 'blacklist'}
|
||||
onChange={(val) => {
|
||||
const selected = val && val.target ? val.target.value : val;
|
||||
const isWhitelist = selected === 'whitelist';
|
||||
setDomainFilterMode(isWhitelist);
|
||||
setInputs(prev => ({
|
||||
...prev,
|
||||
'fetch_setting.domain_filter_mode': isWhitelist,
|
||||
}));
|
||||
}}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Radio value='whitelist'>{t('白名单')}</Radio>
|
||||
<Radio value='blacklist'>{t('黑名单')}</Radio>
|
||||
</Radio.Group>
|
||||
<TagInput
|
||||
value={domainList}
|
||||
onChange={(value) => {
|
||||
setDomainList(value);
|
||||
// 触发Form的onChange事件
|
||||
setInputs(prev => ({
|
||||
...prev,
|
||||
'fetch_setting.domain_list': value
|
||||
}));
|
||||
}}
|
||||
placeholder={t('输入域名后回车,如:example.com')}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
style={{ marginTop: 16 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
|
||||
<Text strong>
|
||||
{t(ipFilterMode ? 'IP白名单' : 'IP黑名单')}
|
||||
</Text>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('支持CIDR格式,如:8.8.8.8, 192.168.1.0/24')}
|
||||
</Text>
|
||||
<Radio.Group
|
||||
type='button'
|
||||
value={ipFilterMode ? 'whitelist' : 'blacklist'}
|
||||
onChange={(val) => {
|
||||
const selected = val && val.target ? val.target.value : val;
|
||||
const isWhitelist = selected === 'whitelist';
|
||||
setIpFilterMode(isWhitelist);
|
||||
setInputs(prev => ({
|
||||
...prev,
|
||||
'fetch_setting.ip_filter_mode': isWhitelist,
|
||||
}));
|
||||
}}
|
||||
style={{ marginBottom: 8 }}
|
||||
>
|
||||
<Radio value='whitelist'>{t('白名单')}</Radio>
|
||||
<Radio value='blacklist'>{t('黑名单')}</Radio>
|
||||
</Radio.Group>
|
||||
<TagInput
|
||||
value={ipList}
|
||||
onChange={(value) => {
|
||||
setIpList(value);
|
||||
// 触发Form的onChange事件
|
||||
setInputs(prev => ({
|
||||
...prev,
|
||||
'fetch_setting.ip_list': value
|
||||
}));
|
||||
}}
|
||||
placeholder={t('输入IP地址后回车,如:8.8.8.8')}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
style={{ marginTop: 16 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={24} lg={24} xl={24}>
|
||||
<Text strong>{t('允许的端口')}</Text>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('支持单个端口和端口范围,如:80, 443, 8000-8999')}
|
||||
</Text>
|
||||
<TagInput
|
||||
value={allowedPorts}
|
||||
onChange={(value) => {
|
||||
setAllowedPorts(value);
|
||||
// 触发Form的onChange事件
|
||||
setInputs(prev => ({
|
||||
...prev,
|
||||
'fetch_setting.allowed_ports': value
|
||||
}));
|
||||
}}
|
||||
placeholder={t('输入端口后回车,如:80 或 8000-8999')}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
<Text type="secondary" style={{ display: 'block', marginBottom: 8 }}>
|
||||
{t('端口配置详细说明')}
|
||||
</Text>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Button onClick={submitSSRF} style={{ marginTop: 16 }}>
|
||||
{t('更新SSRF防护设置')}
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<Form.Section text={t('配置登录注册')}>
|
||||
<Row
|
||||
|
||||
400
web/src/components/settings/oauth2/OAuth2ClientSettings.jsx
Normal file
400
web/src/components/settings/oauth2/OAuth2ClientSettings.jsx
Normal file
@@ -0,0 +1,400 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
Card,
|
||||
Table,
|
||||
Button,
|
||||
Space,
|
||||
Tag,
|
||||
Typography,
|
||||
Input,
|
||||
Popconfirm,
|
||||
Empty,
|
||||
Tooltip,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { User } from 'lucide-react';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { API, showError, showSuccess } from '../../../helpers';
|
||||
import OAuth2ClientModal from './modals/OAuth2ClientModal';
|
||||
import SecretDisplayModal from './modals/SecretDisplayModal';
|
||||
import ServerInfoModal from './modals/ServerInfoModal';
|
||||
import JWKSInfoModal from './modals/JWKSInfoModal';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
export default function OAuth2ClientSettings() {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [clients, setClients] = useState([]);
|
||||
const [filteredClients, setFilteredClients] = useState([]);
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
const [showModal, setShowModal] = useState(false);
|
||||
const [editingClient, setEditingClient] = useState(null);
|
||||
const [showSecretModal, setShowSecretModal] = useState(false);
|
||||
const [currentSecret, setCurrentSecret] = useState('');
|
||||
const [showServerInfoModal, setShowServerInfoModal] = useState(false);
|
||||
const [showJWKSModal, setShowJWKSModal] = useState(false);
|
||||
|
||||
// 加载客户端列表
|
||||
const loadClients = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth_clients/');
|
||||
if (res.data.success) {
|
||||
setClients(res.data.data || []);
|
||||
setFilteredClients(res.data.data || []);
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('加载OAuth2客户端失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 搜索过滤
|
||||
const handleSearch = (value) => {
|
||||
setSearchKeyword(value);
|
||||
if (!value) {
|
||||
setFilteredClients(clients);
|
||||
} else {
|
||||
const filtered = clients.filter(
|
||||
(client) =>
|
||||
client.name?.toLowerCase().includes(value.toLowerCase()) ||
|
||||
client.id?.toLowerCase().includes(value.toLowerCase()) ||
|
||||
client.description?.toLowerCase().includes(value.toLowerCase()),
|
||||
);
|
||||
setFilteredClients(filtered);
|
||||
}
|
||||
};
|
||||
|
||||
// 删除客户端
|
||||
const handleDelete = async (client) => {
|
||||
try {
|
||||
const res = await API.delete(`/api/oauth_clients/${client.id}`);
|
||||
if (res.data.success) {
|
||||
showSuccess(t('删除成功'));
|
||||
loadClients();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('删除失败'));
|
||||
}
|
||||
};
|
||||
|
||||
// 重新生成密钥
|
||||
const handleRegenerateSecret = async (client) => {
|
||||
try {
|
||||
const res = await API.post(
|
||||
`/api/oauth_clients/${client.id}/regenerate_secret`,
|
||||
);
|
||||
if (res.data.success) {
|
||||
setCurrentSecret(res.data.client_secret);
|
||||
setShowSecretModal(true);
|
||||
loadClients();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('重新生成密钥失败'));
|
||||
}
|
||||
};
|
||||
|
||||
// 查看服务器信息
|
||||
const showServerInfo = () => {
|
||||
setShowServerInfoModal(true);
|
||||
};
|
||||
|
||||
// 查看JWKS
|
||||
const showJWKS = () => {
|
||||
setShowJWKSModal(true);
|
||||
};
|
||||
|
||||
// 表格列定义
|
||||
const columns = [
|
||||
{
|
||||
title: t('客户端名称'),
|
||||
dataIndex: 'name',
|
||||
render: (name, record) => (
|
||||
<div className='flex items-center cursor-help'>
|
||||
<User size={16} className='mr-1.5 text-gray-500' />
|
||||
<Tooltip content={record.description || t('暂无描述')} position='top'>
|
||||
<Text strong>{name}</Text>
|
||||
</Tooltip>
|
||||
</div>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('客户端ID'),
|
||||
dataIndex: 'id',
|
||||
render: (id) => (
|
||||
<Text type='tertiary' size='small' code copyable>
|
||||
{id}
|
||||
</Text>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('状态'),
|
||||
dataIndex: 'status',
|
||||
render: (status) => (
|
||||
<Tag color={status === 1 ? 'green' : 'red'} shape='circle'>
|
||||
{status === 1 ? t('启用') : t('禁用')}
|
||||
</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('类型'),
|
||||
dataIndex: 'client_type',
|
||||
render: (text) => (
|
||||
<Tag color='white' shape='circle'>
|
||||
{text === 'confidential' ? t('机密客户端') : t('公开客户端')}
|
||||
</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('授权类型'),
|
||||
dataIndex: 'grant_types',
|
||||
render: (grantTypes) => {
|
||||
const types =
|
||||
typeof grantTypes === 'string'
|
||||
? grantTypes.split(',')
|
||||
: grantTypes || [];
|
||||
const typeMap = {
|
||||
client_credentials: t('客户端凭证'),
|
||||
authorization_code: t('授权码'),
|
||||
refresh_token: t('刷新令牌'),
|
||||
};
|
||||
return (
|
||||
<div className='flex flex-wrap gap-1'>
|
||||
{types.slice(0, 2).map((type) => (
|
||||
<Tag color='white' key={type} size='small' shape='circle'>
|
||||
{typeMap[type] || type}
|
||||
</Tag>
|
||||
))}
|
||||
{types.length > 2 && (
|
||||
<Tooltip
|
||||
content={types
|
||||
.slice(2)
|
||||
.map((t) => typeMap[t] || t)
|
||||
.join(', ')}
|
||||
>
|
||||
<Tag color='white' size='small' shape='circle'>
|
||||
+{types.length - 2}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('创建时间'),
|
||||
dataIndex: 'created_time',
|
||||
render: (time) => new Date(time * 1000).toLocaleString(),
|
||||
},
|
||||
{
|
||||
title: t('操作'),
|
||||
render: (_, record) => (
|
||||
<Space size={4} wrap>
|
||||
<Button
|
||||
type='primary'
|
||||
size='small'
|
||||
onClick={() => {
|
||||
setEditingClient(record);
|
||||
setShowModal(true);
|
||||
}}
|
||||
>
|
||||
{t('编辑')}
|
||||
</Button>
|
||||
{record.client_type === 'confidential' && (
|
||||
<Popconfirm
|
||||
title={t('确认重新生成客户端密钥?')}
|
||||
content={t('操作不可撤销,旧密钥将立即失效。')}
|
||||
onConfirm={() => handleRegenerateSecret(record)}
|
||||
okText={t('确认')}
|
||||
cancelText={t('取消')}
|
||||
position='bottomLeft'
|
||||
>
|
||||
<Button type='secondary' size='small'>
|
||||
{t('重新生成密钥')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
)}
|
||||
<Popconfirm
|
||||
title={t('请再次确认删除该客户端')}
|
||||
content={t('删除后无法恢复,相关 API 调用将立即失效。')}
|
||||
onConfirm={() => handleDelete(record)}
|
||||
okText={t('确定删除')}
|
||||
cancelText={t('取消')}
|
||||
position='bottomLeft'
|
||||
>
|
||||
<Button type='danger' size='small'>
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
fixed: 'right',
|
||||
},
|
||||
];
|
||||
|
||||
useEffect(() => {
|
||||
loadClients();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Card
|
||||
className='!rounded-2xl shadow-sm border-0'
|
||||
style={{ marginTop: 10 }}
|
||||
title={
|
||||
<div
|
||||
className='flex flex-col sm:flex-row sm:items-center sm:justify-between w-full gap-3 sm:gap-0'
|
||||
style={{ paddingRight: '8px' }}
|
||||
>
|
||||
<div className='flex items-center'>
|
||||
<User size={18} className='mr-2' />
|
||||
<Text strong>{t('OAuth2 客户端管理')}</Text>
|
||||
<Tag color='white' shape='circle' size='small' className='ml-2'>
|
||||
{filteredClients.length} {t('个客户端')}
|
||||
</Tag>
|
||||
</div>
|
||||
<div className='flex items-center gap-2 sm:flex-shrink-0 flex-wrap'>
|
||||
<Input
|
||||
prefix={<IconSearch />}
|
||||
placeholder={t('搜索客户端名称、ID或描述')}
|
||||
value={searchKeyword}
|
||||
onChange={handleSearch}
|
||||
showClear
|
||||
size='small'
|
||||
style={{ width: 300 }}
|
||||
/>
|
||||
<Button type='tertiary' onClick={loadClients} size='small'>
|
||||
{t('刷新')}
|
||||
</Button>
|
||||
<Button type='secondary' onClick={showServerInfo} size='small'>
|
||||
{t('服务器信息')}
|
||||
</Button>
|
||||
<Button type='secondary' onClick={showJWKS} size='small'>
|
||||
{t('查看JWKS')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={() => {
|
||||
setEditingClient(null);
|
||||
setShowModal(true);
|
||||
}}
|
||||
size='small'
|
||||
>
|
||||
{t('创建客户端')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div className='mb-4'>
|
||||
<Text type='tertiary'>
|
||||
{t(
|
||||
'管理OAuth2客户端应用程序,每个客户端代表一个可以访问API的应用程序。机密客户端用于服务器端应用,公开客户端用于移动应用或单页应用。',
|
||||
)}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{/* 客户端表格 */}
|
||||
<Table
|
||||
columns={columns}
|
||||
dataSource={filteredClients}
|
||||
rowKey='id'
|
||||
loading={loading}
|
||||
scroll={{ x: 'max-content' }}
|
||||
pagination={{
|
||||
showSizeChanger: true,
|
||||
showQuickJumper: true,
|
||||
showTotal: true,
|
||||
pageSize: 10,
|
||||
}}
|
||||
empty={
|
||||
<Empty
|
||||
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
|
||||
darkModeImage={
|
||||
<IllustrationNoResultDark style={{ width: 150, height: 150 }} />
|
||||
}
|
||||
title={t('暂无OAuth2客户端')}
|
||||
description={t(
|
||||
'还没有创建任何客户端,点击下方按钮创建第一个客户端',
|
||||
)}
|
||||
style={{ padding: 30 }}
|
||||
>
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={() => {
|
||||
setEditingClient(null);
|
||||
setShowModal(true);
|
||||
}}
|
||||
>
|
||||
{t('创建第一个客户端')}
|
||||
</Button>
|
||||
</Empty>
|
||||
}
|
||||
/>
|
||||
|
||||
{/* OAuth2 客户端模态框 */}
|
||||
<OAuth2ClientModal
|
||||
visible={showModal}
|
||||
client={editingClient}
|
||||
onCancel={() => {
|
||||
setShowModal(false);
|
||||
setEditingClient(null);
|
||||
}}
|
||||
onSuccess={() => {
|
||||
setShowModal(false);
|
||||
setEditingClient(null);
|
||||
loadClients();
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* 密钥显示模态框 */}
|
||||
<SecretDisplayModal
|
||||
visible={showSecretModal}
|
||||
onClose={() => setShowSecretModal(false)}
|
||||
secret={currentSecret}
|
||||
/>
|
||||
|
||||
{/* 服务器信息模态框 */}
|
||||
<ServerInfoModal
|
||||
visible={showServerInfoModal}
|
||||
onClose={() => setShowServerInfoModal(false)}
|
||||
/>
|
||||
|
||||
{/* JWKS信息模态框 */}
|
||||
<JWKSInfoModal
|
||||
visible={showJWKSModal}
|
||||
onClose={() => setShowJWKSModal(false)}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
473
web/src/components/settings/oauth2/OAuth2ServerSettings.jsx
Normal file
473
web/src/components/settings/oauth2/OAuth2ServerSettings.jsx
Normal file
@@ -0,0 +1,473 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import {
|
||||
Banner,
|
||||
Button,
|
||||
Col,
|
||||
Form,
|
||||
Row,
|
||||
Card,
|
||||
Typography,
|
||||
Badge,
|
||||
Divider,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { Server } from 'lucide-react';
|
||||
import JWKSManagerModal from './modals/JWKSManagerModal';
|
||||
import {
|
||||
compareObjects,
|
||||
API,
|
||||
showError,
|
||||
showSuccess,
|
||||
showWarning,
|
||||
} from '../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
export default function OAuth2ServerSettings(props) {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [inputs, setInputs] = useState({
|
||||
'oauth2.enabled': false,
|
||||
'oauth2.issuer': '',
|
||||
'oauth2.access_token_ttl': 10,
|
||||
'oauth2.refresh_token_ttl': 720,
|
||||
'oauth2.jwt_signing_algorithm': 'RS256',
|
||||
'oauth2.jwt_key_id': 'oauth2-key-1',
|
||||
'oauth2.allowed_grant_types': [
|
||||
'client_credentials',
|
||||
'authorization_code',
|
||||
'refresh_token',
|
||||
],
|
||||
'oauth2.require_pkce': true,
|
||||
'oauth2.max_jwks_keys': 3,
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
const [keysReady, setKeysReady] = useState(true);
|
||||
const [keysLoading, setKeysLoading] = useState(false);
|
||||
const [serverInfo, setServerInfo] = useState(null);
|
||||
const enabledRef = useRef(inputs['oauth2.enabled']);
|
||||
|
||||
// 模态框状态
|
||||
const [jwksVisible, setJwksVisible] = useState(false);
|
||||
|
||||
function handleFieldChange(fieldName) {
|
||||
return (value) => {
|
||||
setInputs((inputs) => ({ ...inputs, [fieldName]: value }));
|
||||
};
|
||||
}
|
||||
|
||||
function onSubmit() {
|
||||
const updateArray = compareObjects(inputs, inputsRow);
|
||||
if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
|
||||
const requestQueue = updateArray.map((item) => {
|
||||
let value = '';
|
||||
if (typeof inputs[item.key] === 'boolean') {
|
||||
value = String(inputs[item.key]);
|
||||
} else if (Array.isArray(inputs[item.key])) {
|
||||
value = JSON.stringify(inputs[item.key]);
|
||||
} else {
|
||||
value = inputs[item.key];
|
||||
}
|
||||
return API.put('/api/option/', {
|
||||
key: item.key,
|
||||
value,
|
||||
});
|
||||
});
|
||||
setLoading(true);
|
||||
Promise.all(requestQueue)
|
||||
.then((res) => {
|
||||
if (requestQueue.length === 1) {
|
||||
if (res.includes(undefined)) return;
|
||||
} else if (requestQueue.length > 1) {
|
||||
if (res.includes(undefined))
|
||||
return showError(t('部分保存失败,请重试'));
|
||||
}
|
||||
showSuccess(t('保存成功'));
|
||||
if (props && props.refresh) {
|
||||
props.refresh();
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
showError(t('保存失败,请重试'));
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false);
|
||||
});
|
||||
}
|
||||
|
||||
// 测试OAuth2连接(默认静默,仅用户点击时弹提示)
|
||||
const testOAuth2 = async (silent = true) => {
|
||||
// 未启用时不触发测试,避免 404
|
||||
if (!enabledRef.current) return;
|
||||
try {
|
||||
const res = await API.get('/api/oauth/server-info', {
|
||||
skipErrorHandler: true,
|
||||
});
|
||||
if (!enabledRef.current) return;
|
||||
if (
|
||||
res.status === 200 &&
|
||||
(res.data.issuer || res.data.authorization_endpoint)
|
||||
) {
|
||||
if (!silent) showSuccess('OAuth2服务器运行正常');
|
||||
setServerInfo(res.data);
|
||||
} else {
|
||||
if (!enabledRef.current) return;
|
||||
if (!silent) showError('OAuth2服务器测试失败');
|
||||
}
|
||||
} catch (error) {
|
||||
if (!enabledRef.current) return;
|
||||
if (!silent) showError('OAuth2服务器连接测试失败');
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (props && props.options) {
|
||||
const currentInputs = {};
|
||||
for (let key in props.options) {
|
||||
if (Object.keys(inputs).includes(key)) {
|
||||
if (key === 'oauth2.allowed_grant_types') {
|
||||
try {
|
||||
currentInputs[key] = JSON.parse(
|
||||
props.options[key] ||
|
||||
'["client_credentials","authorization_code","refresh_token"]',
|
||||
);
|
||||
} catch {
|
||||
currentInputs[key] = [
|
||||
'client_credentials',
|
||||
'authorization_code',
|
||||
'refresh_token',
|
||||
];
|
||||
}
|
||||
} else if (typeof inputs[key] === 'boolean') {
|
||||
currentInputs[key] = props.options[key] === 'true';
|
||||
} else if (typeof inputs[key] === 'number') {
|
||||
currentInputs[key] = parseInt(props.options[key]) || inputs[key];
|
||||
} else {
|
||||
currentInputs[key] = props.options[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
setInputs({ ...inputs, ...currentInputs });
|
||||
setInputsRow(structuredClone({ ...inputs, ...currentInputs }));
|
||||
if (refForm.current) {
|
||||
refForm.current.setValues({ ...inputs, ...currentInputs });
|
||||
}
|
||||
}
|
||||
}, [props]);
|
||||
|
||||
useEffect(() => {
|
||||
enabledRef.current = inputs['oauth2.enabled'];
|
||||
}, [inputs['oauth2.enabled']]);
|
||||
|
||||
useEffect(() => {
|
||||
const loadKeys = async () => {
|
||||
try {
|
||||
setKeysLoading(true);
|
||||
const res = await API.get('/api/oauth/keys', {
|
||||
skipErrorHandler: true,
|
||||
});
|
||||
const list = res?.data?.data || [];
|
||||
setKeysReady(list.length > 0);
|
||||
} catch {
|
||||
setKeysReady(false);
|
||||
} finally {
|
||||
setKeysLoading(false);
|
||||
}
|
||||
};
|
||||
if (inputs['oauth2.enabled']) {
|
||||
loadKeys();
|
||||
testOAuth2(true);
|
||||
} else {
|
||||
// 禁用时清理状态,避免残留状态与不必要的请求
|
||||
setKeysReady(true);
|
||||
setServerInfo(null);
|
||||
setKeysLoading(false);
|
||||
}
|
||||
}, [inputs['oauth2.enabled']]);
|
||||
|
||||
const isEnabled = inputs['oauth2.enabled'];
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* OAuth2 服务端管理 */}
|
||||
<Card
|
||||
className='!rounded-2xl shadow-sm border-0'
|
||||
style={{ marginTop: 10 }}
|
||||
title={
|
||||
<div
|
||||
className='flex flex-col sm:flex-row sm:items-center sm:justify-between w-full gap-3 sm:gap-0'
|
||||
style={{ paddingRight: '8px' }}
|
||||
>
|
||||
<div className='flex items-center'>
|
||||
<Server size={18} className='mr-2' />
|
||||
<Text strong>{t('OAuth2 服务端管理')}</Text>
|
||||
{isEnabled ? (
|
||||
serverInfo ? (
|
||||
<Badge
|
||||
count={t('运行正常')}
|
||||
type='success'
|
||||
style={{ marginLeft: 8 }}
|
||||
/>
|
||||
) : (
|
||||
<Badge
|
||||
count={t('配置中')}
|
||||
type='warning'
|
||||
style={{ marginLeft: 8 }}
|
||||
/>
|
||||
)
|
||||
) : (
|
||||
<Badge
|
||||
count={t('未启用')}
|
||||
type='tertiary'
|
||||
style={{ marginLeft: 8 }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className='flex items-center gap-2 sm:flex-shrink-0'>
|
||||
{isEnabled && (
|
||||
<Button
|
||||
type='secondary'
|
||||
onClick={() => setJwksVisible(true)}
|
||||
size='small'
|
||||
>
|
||||
{t('密钥管理')}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={onSubmit}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{t('保存配置')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<Form
|
||||
initValues={inputs}
|
||||
getFormApi={(formAPI) => (refForm.current = formAPI)}
|
||||
>
|
||||
{!keysReady && isEnabled && (
|
||||
<Banner
|
||||
type='warning'
|
||||
className='!rounded-lg'
|
||||
closeIcon={null}
|
||||
description={t(
|
||||
'尚未准备签名密钥,建议立即初始化或轮换以发布 JWKS。签名密钥用于 JWT 令牌的安全签发。',
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Row gutter={[16, 24]}>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Switch
|
||||
field='oauth2.enabled'
|
||||
label={t('启用 OAuth2 & SSO')}
|
||||
value={inputs['oauth2.enabled']}
|
||||
onChange={handleFieldChange('oauth2.enabled')}
|
||||
extraText={t('开启后将允许以 OAuth2/OIDC 标准进行授权与登录')}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Input
|
||||
field='oauth2.issuer'
|
||||
label={t('发行人 (Issuer)')}
|
||||
placeholder={window.location.origin}
|
||||
value={inputs['oauth2.issuer']}
|
||||
onChange={handleFieldChange('oauth2.issuer')}
|
||||
extraText={t('为空则按请求自动推断(含 X-Forwarded-Proto)')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* 令牌配置 */}
|
||||
<Divider margin='24px'>{t('令牌配置')}</Divider>
|
||||
|
||||
<Row gutter={[16, 24]}>
|
||||
<Col xs={24} sm={12} lg={8}>
|
||||
<Form.InputNumber
|
||||
field='oauth2.access_token_ttl'
|
||||
label={t('访问令牌有效期')}
|
||||
suffix={t('分钟')}
|
||||
min={1}
|
||||
max={1440}
|
||||
value={inputs['oauth2.access_token_ttl']}
|
||||
onChange={handleFieldChange('oauth2.access_token_ttl')}
|
||||
extraText={t('访问令牌的有效时间,建议较短(10-60分钟)')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={12} lg={8}>
|
||||
<Form.InputNumber
|
||||
field='oauth2.refresh_token_ttl'
|
||||
label={t('刷新令牌有效期')}
|
||||
suffix={t('小时')}
|
||||
min={1}
|
||||
max={8760}
|
||||
value={inputs['oauth2.refresh_token_ttl']}
|
||||
onChange={handleFieldChange('oauth2.refresh_token_ttl')}
|
||||
extraText={t('刷新令牌的有效时间,建议较长(12-720小时)')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={12} lg={8}>
|
||||
<Form.InputNumber
|
||||
field='oauth2.max_jwks_keys'
|
||||
label={t('JWKS历史保留上限')}
|
||||
min={1}
|
||||
max={10}
|
||||
value={inputs['oauth2.max_jwks_keys']}
|
||||
onChange={handleFieldChange('oauth2.max_jwks_keys')}
|
||||
extraText={t('轮换后最多保留的历史签名密钥数量')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={[16, 24]} style={{ marginTop: 16 }}>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Select
|
||||
field='oauth2.jwt_signing_algorithm'
|
||||
label={t('JWT签名算法')}
|
||||
value={inputs['oauth2.jwt_signing_algorithm']}
|
||||
onChange={handleFieldChange('oauth2.jwt_signing_algorithm')}
|
||||
extraText={t('JWT令牌的签名算法,推荐使用RS256')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
>
|
||||
<Form.Select.Option value='RS256'>
|
||||
RS256 (RSA with SHA-256)
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='HS256'>
|
||||
HS256 (HMAC with SHA-256)
|
||||
</Form.Select.Option>
|
||||
</Form.Select>
|
||||
</Col>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Input
|
||||
field='oauth2.jwt_key_id'
|
||||
label={t('JWT密钥ID')}
|
||||
placeholder='oauth2-key-1'
|
||||
value={inputs['oauth2.jwt_key_id']}
|
||||
onChange={handleFieldChange('oauth2.jwt_key_id')}
|
||||
extraText={t('用于标识JWT签名密钥,支持密钥轮换')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* 授权配置 */}
|
||||
<Divider margin='24px'>{t('授权配置')}</Divider>
|
||||
|
||||
<Row gutter={[16, 24]}>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Select
|
||||
field='oauth2.allowed_grant_types'
|
||||
label={t('允许的授权类型')}
|
||||
multiple
|
||||
value={inputs['oauth2.allowed_grant_types']}
|
||||
onChange={handleFieldChange('oauth2.allowed_grant_types')}
|
||||
extraText={t('选择允许的OAuth2授权流程')}
|
||||
style={{
|
||||
width: '100%',
|
||||
opacity: isEnabled ? 1 : 0.5,
|
||||
}}
|
||||
disabled={!isEnabled}
|
||||
>
|
||||
<Form.Select.Option value='client_credentials'>
|
||||
{t('Client Credentials(客户端凭证)')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='authorization_code'>
|
||||
{t('Authorization Code(授权码)')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='refresh_token'>
|
||||
{t('Refresh Token(刷新令牌)')}
|
||||
</Form.Select.Option>
|
||||
</Form.Select>
|
||||
</Col>
|
||||
<Col xs={24} lg={12}>
|
||||
<Form.Switch
|
||||
field='oauth2.require_pkce'
|
||||
label={t('强制PKCE验证')}
|
||||
value={inputs['oauth2.require_pkce']}
|
||||
onChange={handleFieldChange('oauth2.require_pkce')}
|
||||
extraText={t('为授权码流程强制启用PKCE,提高安全性')}
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<div style={{ marginTop: 16 }}>
|
||||
<Text type='tertiary' size='small'>
|
||||
<div className='space-y-1'>
|
||||
<div>• {t('OAuth2 服务器提供标准的 API 认证与授权')}</div>
|
||||
<div>
|
||||
•{' '}
|
||||
{t(
|
||||
'支持 Client Credentials、Authorization Code + PKCE 等标准流程',
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
•{' '}
|
||||
{t(
|
||||
'配置保存后多数项即时生效;签名密钥轮换与 JWKS 发布为即时操作',
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
• {t('生产环境务必启用 HTTPS,并妥善管理 JWT 签名密钥')}
|
||||
</div>
|
||||
</div>
|
||||
</Text>
|
||||
</div>
|
||||
</Form>
|
||||
</Card>
|
||||
|
||||
{/* 模态框 */}
|
||||
<JWKSManagerModal
|
||||
visible={jwksVisible}
|
||||
onClose={() => setJwksVisible(false)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal, Banner, Typography } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const ClientInfoModal = ({ visible, onClose, clientId, clientSecret }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('客户端创建成功')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('我已复制保存')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
>
|
||||
<Banner
|
||||
type='success'
|
||||
closeIcon={null}
|
||||
description={t(
|
||||
'客户端信息如下,请立即复制保存。关闭此窗口后将无法再次查看密钥。',
|
||||
)}
|
||||
className='mb-5 !rounded-lg'
|
||||
/>
|
||||
|
||||
<div className='space-y-4'>
|
||||
<div className='flex justify-center items-center'>
|
||||
<div className='text-center'>
|
||||
<Text strong className='block mb-2'>
|
||||
{t('客户端ID')}
|
||||
</Text>
|
||||
<Text code copyable>
|
||||
{clientId}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{clientSecret && (
|
||||
<div className='flex justify-center items-center'>
|
||||
<div className='text-center'>
|
||||
<Text strong className='block mb-2'>
|
||||
{t('客户端密钥(仅此一次显示)')}
|
||||
</Text>
|
||||
<Text code copyable>
|
||||
{clientSecret}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default ClientInfoModal;
|
||||
70
web/src/components/settings/oauth2/modals/JWKSInfoModal.jsx
Normal file
70
web/src/components/settings/oauth2/modals/JWKSInfoModal.jsx
Normal file
@@ -0,0 +1,70 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CodeViewer from '../../../common/ui/CodeViewer';
|
||||
|
||||
const JWKSInfoModal = ({ visible, onClose }) => {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [jwksInfo, setJwksInfo] = useState(null);
|
||||
|
||||
const loadJWKSInfo = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth/jwks');
|
||||
setJwksInfo(res.data);
|
||||
} catch (error) {
|
||||
showError(t('获取JWKS失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
loadJWKSInfo();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('JWKS 信息')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('关闭')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<CodeViewer
|
||||
content={jwksInfo ? JSON.stringify(jwksInfo, null, 2) : t('加载中...')}
|
||||
title={t('JWKS 密钥集')}
|
||||
language='json'
|
||||
/>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default JWKSInfoModal;
|
||||
399
web/src/components/settings/oauth2/modals/JWKSManagerModal.jsx
Normal file
399
web/src/components/settings/oauth2/modals/JWKSManagerModal.jsx
Normal file
@@ -0,0 +1,399 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import {
|
||||
Table,
|
||||
Button,
|
||||
Space,
|
||||
Tag,
|
||||
Typography,
|
||||
Popconfirm,
|
||||
Toast,
|
||||
Form,
|
||||
Card,
|
||||
Tabs,
|
||||
TabPane,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { API, showError, showSuccess } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ResponsiveModal from '../../../common/ui/ResponsiveModal';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
// 操作模式枚举
|
||||
const OPERATION_MODES = {
|
||||
VIEW: 'view',
|
||||
IMPORT: 'import',
|
||||
GENERATE: 'generate',
|
||||
};
|
||||
|
||||
export default function JWKSManagerModal({ visible, onClose }) {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [keys, setKeys] = useState([]);
|
||||
const [activeTab, setActiveTab] = useState(OPERATION_MODES.VIEW);
|
||||
|
||||
const load = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth/keys');
|
||||
if (res?.data?.success) setKeys(res.data.data || []);
|
||||
else showError(res?.data?.message || t('获取密钥列表失败'));
|
||||
} catch {
|
||||
showError(t('获取密钥列表失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const rotate = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/oauth/keys/rotate', {});
|
||||
if (res?.data?.success) {
|
||||
showSuccess(t('签名密钥已轮换:{{kid}}', { kid: res.data.kid }));
|
||||
await load();
|
||||
} else showError(res?.data?.message || t('密钥轮换失败'));
|
||||
} catch {
|
||||
showError(t('密钥轮换失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const del = async (kid) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.delete(`/api/oauth/keys/${kid}`);
|
||||
if (res?.data?.success) {
|
||||
Toast.success(t('已删除:{{kid}}', { kid }));
|
||||
await load();
|
||||
} else showError(res?.data?.message || t('删除失败'));
|
||||
} catch {
|
||||
showError(t('删除失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Import PEM state
|
||||
const [pem, setPem] = useState('');
|
||||
const [customKid, setCustomKid] = useState('');
|
||||
|
||||
// Generate PEM file state
|
||||
const [genPath, setGenPath] = useState('/etc/new-api/oauth2-private.pem');
|
||||
const [genKid, setGenKid] = useState('');
|
||||
|
||||
// 重置表单数据
|
||||
const resetForms = () => {
|
||||
setPem('');
|
||||
setCustomKid('');
|
||||
setGenKid('');
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
load();
|
||||
// 重置到主视图
|
||||
setActiveTab(OPERATION_MODES.VIEW);
|
||||
} else {
|
||||
// 模态框关闭时重置表单数据
|
||||
resetForms();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!visible) return;
|
||||
(async () => {
|
||||
try {
|
||||
const res = await API.get('/api/oauth/server-info');
|
||||
const p = res?.data?.default_private_key_path;
|
||||
if (p) setGenPath(p);
|
||||
} catch {}
|
||||
})();
|
||||
}, [visible]);
|
||||
|
||||
// 导入 PEM 私钥
|
||||
const importPem = async () => {
|
||||
if (!pem.trim()) return Toast.warning(t('请粘贴 PEM 私钥'));
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/oauth/keys/import_pem', {
|
||||
pem,
|
||||
kid: customKid.trim(),
|
||||
});
|
||||
if (res?.data?.success) {
|
||||
Toast.success(
|
||||
t('已导入私钥并切换到 kid={{kid}}', { kid: res.data.kid }),
|
||||
);
|
||||
resetForms();
|
||||
setActiveTab(OPERATION_MODES.VIEW);
|
||||
await load();
|
||||
} else {
|
||||
Toast.error(res?.data?.message || t('导入失败'));
|
||||
}
|
||||
} catch {
|
||||
Toast.error(t('导入失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 生成 PEM 文件
|
||||
const generatePemFile = async () => {
|
||||
if (!genPath.trim()) return Toast.warning(t('请填写保存路径'));
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.post('/api/oauth/keys/generate_file', {
|
||||
path: genPath.trim(),
|
||||
kid: genKid.trim(),
|
||||
});
|
||||
if (res?.data?.success) {
|
||||
Toast.success(t('已生成并生效:{{path}}', { path: res.data.path }));
|
||||
resetForms();
|
||||
setActiveTab(OPERATION_MODES.VIEW);
|
||||
await load();
|
||||
} else {
|
||||
Toast.error(res?.data?.message || t('生成失败'));
|
||||
}
|
||||
} catch {
|
||||
Toast.error(t('生成失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: 'KID',
|
||||
dataIndex: 'kid',
|
||||
render: (kid) => (
|
||||
<Text code copyable>
|
||||
{kid}
|
||||
</Text>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('创建时间'),
|
||||
dataIndex: 'created_at',
|
||||
render: (ts) => (ts ? new Date(ts * 1000).toLocaleString() : '-'),
|
||||
},
|
||||
{
|
||||
title: t('状态'),
|
||||
dataIndex: 'current',
|
||||
render: (cur) =>
|
||||
cur ? (
|
||||
<Tag color='green' shape='circle'>
|
||||
{t('当前')}
|
||||
</Tag>
|
||||
) : (
|
||||
<Tag shape='circle'>{t('历史')}</Tag>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: t('操作'),
|
||||
render: (_, r) => (
|
||||
<Space>
|
||||
{!r.current && (
|
||||
<Popconfirm
|
||||
title={t('确定删除密钥 {{kid}} ?', { kid: r.kid })}
|
||||
content={t(
|
||||
'删除后使用该 kid 签发的旧令牌仍可被验证(外部 JWKS 缓存可能仍保留)',
|
||||
)}
|
||||
okText={t('删除')}
|
||||
onConfirm={() => del(r.kid)}
|
||||
>
|
||||
<Button size='small' type='danger'>
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
)}
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
// 头部操作按钮 - 根据当前标签页动态生成
|
||||
const getHeaderActions = () => {
|
||||
if (activeTab === OPERATION_MODES.VIEW) {
|
||||
const hasKeys = Array.isArray(keys) && keys.length > 0;
|
||||
return [
|
||||
<Button key='refresh' onClick={load} loading={loading} size='small'>
|
||||
{t('刷新')}
|
||||
</Button>,
|
||||
<Button
|
||||
key='rotate'
|
||||
type='primary'
|
||||
onClick={rotate}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{hasKeys ? t('轮换密钥') : t('初始化密钥')}
|
||||
</Button>,
|
||||
];
|
||||
}
|
||||
|
||||
if (activeTab === OPERATION_MODES.IMPORT) {
|
||||
return [
|
||||
<Button
|
||||
key='import'
|
||||
type='primary'
|
||||
onClick={importPem}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{t('导入并生效')}
|
||||
</Button>,
|
||||
];
|
||||
}
|
||||
|
||||
if (activeTab === OPERATION_MODES.GENERATE) {
|
||||
return [
|
||||
<Button
|
||||
key='generate'
|
||||
type='primary'
|
||||
onClick={generatePemFile}
|
||||
loading={loading}
|
||||
size='small'
|
||||
>
|
||||
{t('生成并生效')}
|
||||
</Button>,
|
||||
];
|
||||
}
|
||||
|
||||
return [];
|
||||
};
|
||||
|
||||
// 渲染密钥列表视图
|
||||
const renderKeysView = () => (
|
||||
<Card
|
||||
className='!rounded-lg'
|
||||
title={
|
||||
<Text className='text-blue-700 dark:text-blue-300'>
|
||||
{t(
|
||||
'提示:当前密钥用于签发 JWT 令牌。建议定期轮换密钥以提升安全性。只有历史密钥可以删除。',
|
||||
)}
|
||||
</Text>
|
||||
}
|
||||
>
|
||||
<Table
|
||||
dataSource={keys}
|
||||
columns={columns}
|
||||
rowKey='kid'
|
||||
loading={loading}
|
||||
pagination={false}
|
||||
empty={<Text type='tertiary'>{t('暂无密钥')}</Text>}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
|
||||
// 渲染导入 PEM 私钥视图
|
||||
const renderImportView = () => (
|
||||
<Card
|
||||
className='!rounded-lg'
|
||||
title={
|
||||
<Text className='text-yellow-700 dark:text-yellow-300'>
|
||||
{t(
|
||||
'建议:优先使用内存签名密钥与 JWKS 轮换;仅在有合规要求时导入外部私钥。请确保私钥来源可信。',
|
||||
)}
|
||||
</Text>
|
||||
}
|
||||
>
|
||||
<Form labelPosition='left' labelWidth={120}>
|
||||
<Form.Input
|
||||
field='kid'
|
||||
label={t('自定义 KID')}
|
||||
placeholder={t('可留空自动生成')}
|
||||
value={customKid}
|
||||
onChange={setCustomKid}
|
||||
/>
|
||||
<Form.TextArea
|
||||
field='pem'
|
||||
label={t('PEM 私钥')}
|
||||
value={pem}
|
||||
onChange={setPem}
|
||||
rows={8}
|
||||
placeholder={
|
||||
'-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----'
|
||||
}
|
||||
/>
|
||||
</Form>
|
||||
</Card>
|
||||
);
|
||||
|
||||
// 渲染生成 PEM 文件视图
|
||||
const renderGenerateView = () => (
|
||||
<Card
|
||||
className='!rounded-lg'
|
||||
title={
|
||||
<Text className='text-orange-700 dark:text-orange-300'>
|
||||
{t(
|
||||
'建议:仅在合规要求下使用文件私钥。请确保目录权限安全(建议 0600),并妥善备份。',
|
||||
)}
|
||||
</Text>
|
||||
}
|
||||
>
|
||||
<Form labelPosition='left' labelWidth={120}>
|
||||
<Form.Input
|
||||
field='path'
|
||||
label={t('保存路径')}
|
||||
value={genPath}
|
||||
onChange={setGenPath}
|
||||
placeholder='/secure/path/oauth2-private.pem'
|
||||
/>
|
||||
<Form.Input
|
||||
field='genKid'
|
||||
label={t('自定义 KID')}
|
||||
value={genKid}
|
||||
onChange={setGenKid}
|
||||
placeholder={t('可留空自动生成')}
|
||||
/>
|
||||
</Form>
|
||||
</Card>
|
||||
);
|
||||
|
||||
return (
|
||||
<ResponsiveModal
|
||||
visible={visible}
|
||||
title={t('JWKS 管理')}
|
||||
headerActions={getHeaderActions()}
|
||||
onCancel={onClose}
|
||||
footer={null}
|
||||
width={{ mobile: '95%', desktop: 800 }}
|
||||
>
|
||||
<Tabs
|
||||
activeKey={activeTab}
|
||||
onChange={setActiveTab}
|
||||
type='card'
|
||||
size='medium'
|
||||
className='!-mt-2'
|
||||
>
|
||||
<TabPane tab={t('密钥列表')} itemKey={OPERATION_MODES.VIEW}>
|
||||
{renderKeysView()}
|
||||
</TabPane>
|
||||
<TabPane tab={t('导入 PEM 私钥')} itemKey={OPERATION_MODES.IMPORT}>
|
||||
{renderImportView()}
|
||||
</TabPane>
|
||||
<TabPane tab={t('生成 PEM 文件')} itemKey={OPERATION_MODES.GENERATE}>
|
||||
{renderGenerateView()}
|
||||
</TabPane>
|
||||
</Tabs>
|
||||
</ResponsiveModal>
|
||||
);
|
||||
}
|
||||
730
web/src/components/settings/oauth2/modals/OAuth2ClientModal.jsx
Normal file
730
web/src/components/settings/oauth2/modals/OAuth2ClientModal.jsx
Normal file
@@ -0,0 +1,730 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import {
|
||||
SideSheet,
|
||||
Form,
|
||||
Input,
|
||||
Select,
|
||||
Space,
|
||||
Typography,
|
||||
Button,
|
||||
Card,
|
||||
Avatar,
|
||||
Tag,
|
||||
Spin,
|
||||
Radio,
|
||||
Divider,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconKey,
|
||||
IconLink,
|
||||
IconSave,
|
||||
IconClose,
|
||||
IconPlus,
|
||||
IconDelete,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { API, showError, showSuccess } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
|
||||
import ClientInfoModal from './ClientInfoModal';
|
||||
|
||||
const { Text, Title } = Typography;
|
||||
const { Option } = Select;
|
||||
|
||||
const AUTH_CODE = 'authorization_code';
|
||||
const CLIENT_CREDENTIALS = 'client_credentials';
|
||||
|
||||
// 子组件:重定向URI编辑卡片
|
||||
function RedirectUriCard({
|
||||
t,
|
||||
isAuthCodeSelected,
|
||||
redirectUris,
|
||||
onAdd,
|
||||
onUpdate,
|
||||
onRemove,
|
||||
onFillTemplate,
|
||||
}) {
|
||||
return (
|
||||
<Card
|
||||
header={
|
||||
<div className='flex justify-between items-center'>
|
||||
<div className='flex items-center'>
|
||||
<Avatar size='small' color='purple' className='mr-2 shadow-md'>
|
||||
<IconLink size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className='text-lg font-medium'>{t('重定向URI配置')}</Text>
|
||||
<div className='text-xs text-gray-600'>
|
||||
{t('用于授权码流程的重定向地址')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
type='tertiary'
|
||||
onClick={onFillTemplate}
|
||||
size='small'
|
||||
disabled={!isAuthCodeSelected}
|
||||
>
|
||||
{t('填入示例模板')}
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
headerStyle={{ padding: '12px 16px' }}
|
||||
bodyStyle={{ padding: '16px' }}
|
||||
className='!rounded-2xl shadow-sm border-0'
|
||||
>
|
||||
<div className='space-y-1'>
|
||||
{redirectUris.length === 0 && (
|
||||
<div className='text-center py-4 px-4'>
|
||||
<Text type='tertiary' className='text-gray-500 text-sm'>
|
||||
{t('暂无重定向URI,点击下方按钮添加')}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{redirectUris.map((uri, index) => (
|
||||
<div
|
||||
key={index}
|
||||
style={{
|
||||
marginBottom: 8,
|
||||
display: 'flex',
|
||||
gap: 8,
|
||||
alignItems: 'center',
|
||||
}}
|
||||
>
|
||||
<Input
|
||||
placeholder={t('例如:https://your-app.com/callback')}
|
||||
value={uri}
|
||||
onChange={(value) => onUpdate(index, value)}
|
||||
style={{ flex: 1 }}
|
||||
disabled={!isAuthCodeSelected}
|
||||
/>
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
type='danger'
|
||||
theme='borderless'
|
||||
onClick={() => onRemove(index)}
|
||||
disabled={!isAuthCodeSelected}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<div className='py-2 flex justify-center gap-2'>
|
||||
<Button
|
||||
icon={<IconPlus />}
|
||||
type='primary'
|
||||
theme='outline'
|
||||
onClick={onAdd}
|
||||
disabled={!isAuthCodeSelected}
|
||||
>
|
||||
{t('添加重定向URI')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Divider margin='12px' align='center'>
|
||||
<Text type='tertiary' size='small'>
|
||||
{isAuthCodeSelected
|
||||
? t(
|
||||
'用户授权后将重定向到这些URI。必须使用HTTPS(本地开发可使用HTTP,仅限localhost/127.0.0.1)',
|
||||
)
|
||||
: t('仅在选择“授权码”授权类型时需要配置重定向URI')}
|
||||
</Text>
|
||||
</Divider>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
const OAuth2ClientModal = ({ visible, client, onCancel, onSuccess }) => {
|
||||
const { t } = useTranslation();
|
||||
const isMobile = useIsMobile();
|
||||
const formApiRef = useRef(null);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [redirectUris, setRedirectUris] = useState([]);
|
||||
const [clientType, setClientType] = useState('confidential');
|
||||
const [grantTypes, setGrantTypes] = useState([]);
|
||||
const [allowedGrantTypes, setAllowedGrantTypes] = useState([
|
||||
CLIENT_CREDENTIALS,
|
||||
AUTH_CODE,
|
||||
'refresh_token',
|
||||
]);
|
||||
|
||||
// ClientInfoModal 状态
|
||||
const [showClientInfo, setShowClientInfo] = useState(false);
|
||||
const [clientInfo, setClientInfo] = useState({
|
||||
clientId: '',
|
||||
clientSecret: '',
|
||||
});
|
||||
|
||||
const isEdit = client?.id !== undefined;
|
||||
const [mode, setMode] = useState('create'); // 'create' | 'edit'
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
setMode(isEdit ? 'edit' : 'create');
|
||||
}
|
||||
}, [visible, isEdit]);
|
||||
|
||||
const getInitValues = () => ({
|
||||
name: '',
|
||||
description: '',
|
||||
client_type: 'confidential',
|
||||
grant_types: [],
|
||||
scopes: [],
|
||||
require_pkce: true,
|
||||
status: 1,
|
||||
});
|
||||
|
||||
// 加载后端允许的授权类型
|
||||
useEffect(() => {
|
||||
let mounted = true;
|
||||
(async () => {
|
||||
try {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, data } = res.data || {};
|
||||
if (!success || !Array.isArray(data)) return;
|
||||
const found = data.find((i) => i.key === 'oauth2.allowed_grant_types');
|
||||
if (!found) return;
|
||||
let parsed = [];
|
||||
try {
|
||||
parsed = JSON.parse(found.value || '[]');
|
||||
} catch (_) {}
|
||||
if (mounted && Array.isArray(parsed) && parsed.length) {
|
||||
setAllowedGrantTypes(parsed);
|
||||
}
|
||||
} catch (_) {
|
||||
// 忽略错误,使用默认allowedGrantTypes
|
||||
}
|
||||
})();
|
||||
return () => {
|
||||
mounted = false;
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
setGrantTypes((prev) => {
|
||||
const normalizedPrev = Array.isArray(prev) ? prev : [];
|
||||
// 移除不被允许或与客户端类型冲突的类型
|
||||
let next = normalizedPrev.filter((g) => allowedGrantTypes.includes(g));
|
||||
if (clientType === 'public') {
|
||||
next = next.filter((g) => g !== CLIENT_CREDENTIALS);
|
||||
}
|
||||
return next.length ? next : [];
|
||||
});
|
||||
}, [clientType, allowedGrantTypes]);
|
||||
|
||||
// 初始化表单数据(编辑模式)
|
||||
useEffect(() => {
|
||||
if (client && visible && isEdit) {
|
||||
setLoading(true);
|
||||
// 解析授权类型
|
||||
let parsedGrantTypes = [];
|
||||
if (typeof client.grant_types === 'string') {
|
||||
parsedGrantTypes = client.grant_types.split(',');
|
||||
} else if (Array.isArray(client.grant_types)) {
|
||||
parsedGrantTypes = client.grant_types;
|
||||
}
|
||||
|
||||
// 解析Scope
|
||||
let parsedScopes = [];
|
||||
if (typeof client.scopes === 'string') {
|
||||
parsedScopes = client.scopes.split(',');
|
||||
} else if (Array.isArray(client.scopes)) {
|
||||
parsedScopes = client.scopes;
|
||||
}
|
||||
if (!parsedScopes || parsedScopes.length === 0) {
|
||||
parsedScopes = ['openid', 'profile', 'email', 'api:read'];
|
||||
}
|
||||
|
||||
// 解析重定向URI
|
||||
let parsedRedirectUris = [];
|
||||
if (client.redirect_uris) {
|
||||
try {
|
||||
const parsed =
|
||||
typeof client.redirect_uris === 'string'
|
||||
? JSON.parse(client.redirect_uris)
|
||||
: client.redirect_uris;
|
||||
if (Array.isArray(parsed) && parsed.length > 0) {
|
||||
parsedRedirectUris = parsed;
|
||||
}
|
||||
} catch (e) {}
|
||||
}
|
||||
|
||||
// 过滤不被允许或不兼容的授权类型
|
||||
const filteredGrantTypes = (parsedGrantTypes || []).filter((g) =>
|
||||
allowedGrantTypes.includes(g),
|
||||
);
|
||||
const finalGrantTypes =
|
||||
client.client_type === 'public'
|
||||
? filteredGrantTypes.filter((g) => g !== CLIENT_CREDENTIALS)
|
||||
: filteredGrantTypes;
|
||||
|
||||
setClientType(client.client_type);
|
||||
setGrantTypes(finalGrantTypes);
|
||||
// 不自动新增空白URI,保持与创建模式一致的手动添加体验
|
||||
setRedirectUris(parsedRedirectUris);
|
||||
|
||||
// 设置表单值
|
||||
const formValues = {
|
||||
id: client.id,
|
||||
name: client.name,
|
||||
description: client.description,
|
||||
client_type: client.client_type,
|
||||
grant_types: finalGrantTypes,
|
||||
scopes: parsedScopes,
|
||||
require_pkce: !!client.require_pkce,
|
||||
status: client.status,
|
||||
};
|
||||
|
||||
setTimeout(() => {
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(formValues);
|
||||
}
|
||||
setLoading(false);
|
||||
}, 100);
|
||||
} else if (visible && !isEdit) {
|
||||
// 创建模式,重置状态
|
||||
setClientType('confidential');
|
||||
setGrantTypes([]);
|
||||
setRedirectUris([]);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValues(getInitValues());
|
||||
}
|
||||
}
|
||||
}, [client, visible, isEdit, allowedGrantTypes]);
|
||||
|
||||
const isAuthCodeSelected = grantTypes.includes(AUTH_CODE);
|
||||
const isGrantTypeDisabled = (value) => {
|
||||
if (!allowedGrantTypes.includes(value)) return true;
|
||||
if (clientType === 'public' && value === CLIENT_CREDENTIALS) return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// URL校验:允许 https;http 仅限本地开发域名
|
||||
const isValidRedirectUri = (uri) => {
|
||||
if (!uri || !uri.trim()) return false;
|
||||
try {
|
||||
const u = new URL(uri.trim());
|
||||
if (u.protocol === 'https:') return true;
|
||||
if (u.protocol === 'http:') {
|
||||
const host = u.hostname;
|
||||
return (
|
||||
host === 'localhost' ||
|
||||
host === '127.0.0.1' ||
|
||||
host.endsWith('.local')
|
||||
);
|
||||
}
|
||||
return false;
|
||||
} catch (_) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// 处理提交
|
||||
const handleSubmit = async (values) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
// 过滤空的重定向URI
|
||||
const validRedirectUris = redirectUris
|
||||
.map((u) => (u || '').trim())
|
||||
.filter((u) => u.length > 0);
|
||||
|
||||
// 业务校验
|
||||
if (!grantTypes.length) {
|
||||
showError(t('请至少选择一种授权类型'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// 校验是否包含不被允许的授权类型
|
||||
const invalids = grantTypes.filter((g) => !allowedGrantTypes.includes(g));
|
||||
if (invalids.length) {
|
||||
showError(
|
||||
t('不被允许的授权类型: {{types}}', { types: invalids.join(', ') }),
|
||||
);
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (clientType === 'public' && grantTypes.includes(CLIENT_CREDENTIALS)) {
|
||||
showError(t('公开客户端不允许使用client_credentials授权类型'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (grantTypes.includes(AUTH_CODE)) {
|
||||
if (!validRedirectUris.length) {
|
||||
showError(t('选择授权码授权类型时,必须填写至少一个重定向URI'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
const allValid = validRedirectUris.every(isValidRedirectUri);
|
||||
if (!allValid) {
|
||||
showError(t('重定向URI格式不合法:仅支持https,或本地开发使用http'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 避免把 Radio 组件对象形式的 client_type 直接传给后端
|
||||
const { client_type: _formClientType, ...restValues } = values || {};
|
||||
const payload = {
|
||||
...restValues,
|
||||
client_type: clientType,
|
||||
grant_types: grantTypes,
|
||||
redirect_uris: validRedirectUris,
|
||||
};
|
||||
|
||||
let res;
|
||||
if (isEdit) {
|
||||
res = await API.put('/api/oauth_clients/', payload);
|
||||
} else {
|
||||
res = await API.post('/api/oauth_clients/', payload);
|
||||
}
|
||||
|
||||
const { success, message, client_id, client_secret } = res.data;
|
||||
|
||||
if (success) {
|
||||
if (isEdit) {
|
||||
showSuccess(t('OAuth2客户端更新成功'));
|
||||
resetForm();
|
||||
onSuccess();
|
||||
} else {
|
||||
showSuccess(t('OAuth2客户端创建成功'));
|
||||
// 显示客户端信息
|
||||
setClientInfo({
|
||||
clientId: client_id,
|
||||
clientSecret: client_secret,
|
||||
});
|
||||
setShowClientInfo(true);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(isEdit ? t('更新OAuth2客户端失败') : t('创建OAuth2客户端失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
// 重置表单
|
||||
const resetForm = () => {
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.reset();
|
||||
}
|
||||
setClientType('confidential');
|
||||
setGrantTypes([]);
|
||||
setRedirectUris([]);
|
||||
};
|
||||
|
||||
// 处理ClientInfoModal关闭
|
||||
const handleClientInfoClose = () => {
|
||||
setShowClientInfo(false);
|
||||
setClientInfo({ clientId: '', clientSecret: '' });
|
||||
resetForm();
|
||||
onSuccess();
|
||||
};
|
||||
|
||||
// 处理取消
|
||||
const handleCancel = () => {
|
||||
resetForm();
|
||||
onCancel();
|
||||
};
|
||||
|
||||
// 添加重定向URI
|
||||
const addRedirectUri = () => {
|
||||
setRedirectUris([...redirectUris, '']);
|
||||
};
|
||||
|
||||
// 删除重定向URI
|
||||
const removeRedirectUri = (index) => {
|
||||
setRedirectUris(redirectUris.filter((_, i) => i !== index));
|
||||
};
|
||||
|
||||
// 更新重定向URI
|
||||
const updateRedirectUri = (index, value) => {
|
||||
const newUris = [...redirectUris];
|
||||
newUris[index] = value;
|
||||
setRedirectUris(newUris);
|
||||
};
|
||||
|
||||
// 填入示例重定向URI模板
|
||||
const fillRedirectUriTemplate = () => {
|
||||
const template = [
|
||||
'https://your-app.com/auth/callback',
|
||||
'https://localhost:3000/callback',
|
||||
];
|
||||
setRedirectUris(template);
|
||||
};
|
||||
|
||||
// 授权类型变化处理(清理非法项,只设置一次)
|
||||
const handleGrantTypesChange = (values) => {
|
||||
const allowed = Array.isArray(values)
|
||||
? values.filter((v) => allowedGrantTypes.includes(v))
|
||||
: [];
|
||||
const sanitized =
|
||||
clientType === 'public'
|
||||
? allowed.filter((v) => v !== CLIENT_CREDENTIALS)
|
||||
: allowed;
|
||||
setGrantTypes(sanitized);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('grant_types', sanitized);
|
||||
}
|
||||
};
|
||||
|
||||
// 客户端类型变化处理(兼容 RadioGroup 事件对象与直接值)
|
||||
const handleClientTypeChange = (next) => {
|
||||
const value = next && next.target ? next.target.value : next;
|
||||
setClientType(value);
|
||||
// 公开客户端自动移除 client_credentials,并同步表单字段
|
||||
const current = Array.isArray(grantTypes) ? grantTypes : [];
|
||||
const sanitized =
|
||||
value === 'public'
|
||||
? current.filter((g) => g !== CLIENT_CREDENTIALS)
|
||||
: current;
|
||||
if (sanitized !== current) {
|
||||
setGrantTypes(sanitized);
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('grant_types', sanitized);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<SideSheet
|
||||
placement={mode === 'edit' ? 'right' : 'left'}
|
||||
title={
|
||||
<Space>
|
||||
{mode === 'edit' ? (
|
||||
<Tag color='blue' shape='circle'>
|
||||
{t('编辑')}
|
||||
</Tag>
|
||||
) : (
|
||||
<Tag color='green' shape='circle'>
|
||||
{t('创建')}
|
||||
</Tag>
|
||||
)}
|
||||
<Title heading={4} className='m-0'>
|
||||
{mode === 'edit' ? t('编辑OAuth2客户端') : t('创建OAuth2客户端')}
|
||||
</Title>
|
||||
</Space>
|
||||
}
|
||||
bodyStyle={{ padding: '0' }}
|
||||
visible={visible}
|
||||
width={isMobile ? '100%' : 700}
|
||||
footer={
|
||||
<div className='flex justify-end bg-white'>
|
||||
<Space>
|
||||
<Button
|
||||
theme='solid'
|
||||
className='!rounded-lg'
|
||||
onClick={() => formApiRef.current?.submitForm()}
|
||||
icon={<IconSave />}
|
||||
loading={loading}
|
||||
>
|
||||
{isEdit ? t('保存') : t('创建')}
|
||||
</Button>
|
||||
<Button
|
||||
theme='light'
|
||||
className='!rounded-lg'
|
||||
type='primary'
|
||||
onClick={handleCancel}
|
||||
icon={<IconClose />}
|
||||
>
|
||||
{t('取消')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
}
|
||||
closeIcon={null}
|
||||
onCancel={handleCancel}
|
||||
>
|
||||
<Spin spinning={loading}>
|
||||
<Form
|
||||
key={isEdit ? `edit-${client?.id}` : 'create'}
|
||||
initValues={getInitValues()}
|
||||
getFormApi={(api) => (formApiRef.current = api)}
|
||||
onSubmit={handleSubmit}
|
||||
>
|
||||
{() => (
|
||||
<div className='p-2'>
|
||||
{/* 表单内容 */}
|
||||
{/* 基本信息 */}
|
||||
<Card className='!rounded-2xl shadow-sm border-0'>
|
||||
<div className='flex items-center mb-4'>
|
||||
<Avatar size='small' color='blue' className='mr-2 shadow-md'>
|
||||
<IconKey size={16} />
|
||||
</Avatar>
|
||||
<div>
|
||||
<Text className='text-lg font-medium'>{t('基本信息')}</Text>
|
||||
<div className='text-xs text-gray-600'>
|
||||
{t('设置客户端的基本信息')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{isEdit && (
|
||||
<>
|
||||
<Form.Select
|
||||
field='status'
|
||||
label={t('状态')}
|
||||
rules={[{ required: true, message: t('请选择状态') }]}
|
||||
required
|
||||
>
|
||||
<Option value={1}>{t('启用')}</Option>
|
||||
<Option value={2}>{t('禁用')}</Option>
|
||||
</Form.Select>
|
||||
<Form.Input field='id' label={t('客户端ID')} disabled />
|
||||
</>
|
||||
)}
|
||||
<Form.Input
|
||||
field='name'
|
||||
label={t('客户端名称')}
|
||||
placeholder={t('输入客户端名称')}
|
||||
rules={[{ required: true, message: t('请输入客户端名称') }]}
|
||||
required
|
||||
showClear
|
||||
/>
|
||||
<Form.TextArea
|
||||
field='description'
|
||||
label={t('描述')}
|
||||
placeholder={t('输入客户端描述')}
|
||||
rows={3}
|
||||
showClear
|
||||
/>
|
||||
<Form.RadioGroup
|
||||
label={t('客户端类型')}
|
||||
field='client_type'
|
||||
value={clientType}
|
||||
onChange={handleClientTypeChange}
|
||||
type='card'
|
||||
aria-label={t('选择客户端类型')}
|
||||
disabled={isEdit}
|
||||
rules={[{ required: true, message: t('请选择客户端类型') }]}
|
||||
required
|
||||
>
|
||||
<Radio
|
||||
value='confidential'
|
||||
extra={t('服务器端应用,安全地存储客户端密钥')}
|
||||
style={{ width: isMobile ? '100%' : 'auto' }}
|
||||
>
|
||||
{t('机密客户端(Confidential)')}
|
||||
</Radio>
|
||||
<Radio
|
||||
value='public'
|
||||
extra={t('移动应用或单页应用,无法安全存储密钥')}
|
||||
style={{ width: isMobile ? '100%' : 'auto' }}
|
||||
>
|
||||
{t('公开客户端(Public)')}
|
||||
</Radio>
|
||||
</Form.RadioGroup>
|
||||
<Form.Select
|
||||
field='grant_types'
|
||||
label={t('允许的授权类型')}
|
||||
multiple
|
||||
value={grantTypes}
|
||||
onChange={handleGrantTypesChange}
|
||||
rules={[
|
||||
{ required: true, message: t('请选择至少一种授权类型') },
|
||||
]}
|
||||
required
|
||||
placeholder={t('请选择授权类型(可多选)')}
|
||||
>
|
||||
{clientType !== 'public' && (
|
||||
<Option
|
||||
value={CLIENT_CREDENTIALS}
|
||||
disabled={isGrantTypeDisabled(CLIENT_CREDENTIALS)}
|
||||
>
|
||||
{t('Client Credentials(客户端凭证)')}
|
||||
</Option>
|
||||
)}
|
||||
<Option
|
||||
value={AUTH_CODE}
|
||||
disabled={isGrantTypeDisabled(AUTH_CODE)}
|
||||
>
|
||||
{t('Authorization Code(授权码)')}
|
||||
</Option>
|
||||
<Option
|
||||
value='refresh_token'
|
||||
disabled={isGrantTypeDisabled('refresh_token')}
|
||||
>
|
||||
{t('Refresh Token(刷新令牌)')}
|
||||
</Option>
|
||||
</Form.Select>
|
||||
<Form.Select
|
||||
field='scopes'
|
||||
label={t('允许的权限范围(Scope)')}
|
||||
multiple
|
||||
rules={[
|
||||
{ required: true, message: t('请选择至少一个权限范围') },
|
||||
]}
|
||||
required
|
||||
placeholder={t('请选择权限范围(可多选)')}
|
||||
>
|
||||
<Option value='openid'>{t('openid(OIDC 基础身份)')}</Option>
|
||||
<Option value='profile'>
|
||||
{t('profile(用户名/昵称等)')}
|
||||
</Option>
|
||||
<Option value='email'>{t('email(邮箱信息)')}</Option>
|
||||
<Option value='api:read'>
|
||||
{`api:read (${t('读取API')})`}
|
||||
</Option>
|
||||
<Option value='api:write'>
|
||||
{`api:write (${t('写入API')})`}
|
||||
</Option>
|
||||
<Option value='admin'>{t('admin(管理员权限)')}</Option>
|
||||
</Form.Select>
|
||||
<Form.Switch
|
||||
field='require_pkce'
|
||||
label={t('强制PKCE验证')}
|
||||
size='large'
|
||||
extraText={t(
|
||||
'PKCE(Proof Key for Code Exchange)可提高授权码流程的安全性。',
|
||||
)}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* 重定向URI */}
|
||||
<RedirectUriCard
|
||||
t={t}
|
||||
isAuthCodeSelected={isAuthCodeSelected}
|
||||
redirectUris={redirectUris}
|
||||
onAdd={addRedirectUri}
|
||||
onUpdate={updateRedirectUri}
|
||||
onRemove={removeRedirectUri}
|
||||
onFillTemplate={fillRedirectUriTemplate}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Form>
|
||||
</Spin>
|
||||
|
||||
{/* 客户端信息展示模态框 */}
|
||||
<ClientInfoModal
|
||||
visible={showClientInfo}
|
||||
onClose={handleClientInfoClose}
|
||||
clientId={clientInfo.clientId}
|
||||
clientSecret={clientInfo.clientSecret}
|
||||
/>
|
||||
</SideSheet>
|
||||
);
|
||||
};
|
||||
|
||||
export default OAuth2ClientModal;
|
||||
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Modal, Banner, Typography } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
const SecretDisplayModal = ({ visible, onClose, secret }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('客户端密钥已重新生成')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('我已复制保存')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
>
|
||||
<Banner
|
||||
type='success'
|
||||
closeIcon={null}
|
||||
description={t(
|
||||
'新的客户端密钥如下,请立即复制保存。关闭此窗口后将无法再次查看。',
|
||||
)}
|
||||
className='mb-5 !rounded-lg'
|
||||
/>
|
||||
<div className='flex justify-center items-center'>
|
||||
<Text code copyable>
|
||||
{secret}
|
||||
</Text>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default SecretDisplayModal;
|
||||
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Modal } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../../../helpers';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import CodeViewer from '../../../common/ui/CodeViewer';
|
||||
|
||||
const ServerInfoModal = ({ visible, onClose }) => {
|
||||
const { t } = useTranslation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [serverInfo, setServerInfo] = useState(null);
|
||||
|
||||
const loadServerInfo = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const res = await API.get('/api/oauth/server-info');
|
||||
setServerInfo(res.data);
|
||||
} catch (error) {
|
||||
showError(t('获取服务器信息失败'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
loadServerInfo();
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('OAuth2 服务器信息')}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
onOk={onClose}
|
||||
cancelText=''
|
||||
okText={t('关闭')}
|
||||
width={650}
|
||||
bodyStyle={{ padding: '20px 24px' }}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<CodeViewer
|
||||
content={
|
||||
serverInfo ? JSON.stringify(serverInfo, null, 2) : t('加载中...')
|
||||
}
|
||||
title={t('OAuth2 服务器配置')}
|
||||
language='json'
|
||||
/>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
export default ServerInfoModal;
|
||||
@@ -28,7 +28,6 @@ import {
|
||||
Tabs,
|
||||
TabPane,
|
||||
Popover,
|
||||
Modal,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IconMail,
|
||||
@@ -84,9 +83,6 @@ const AccountManagement = ({
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
const isBound = (accountId) => Boolean(accountId);
|
||||
const [showTelegramBindModal, setShowTelegramBindModal] = React.useState(false);
|
||||
|
||||
return (
|
||||
<Card className='!rounded-2xl'>
|
||||
{/* 卡片头部 */}
|
||||
@@ -146,7 +142,7 @@ const AccountManagement = ({
|
||||
size='small'
|
||||
onClick={() => setShowEmailBindModal(true)}
|
||||
>
|
||||
{isBound(userState.user?.email)
|
||||
{userState.user && userState.user.email !== ''
|
||||
? t('修改绑定')
|
||||
: t('绑定')}
|
||||
</Button>
|
||||
@@ -169,11 +165,9 @@ const AccountManagement = ({
|
||||
{t('微信')}
|
||||
</div>
|
||||
<div className='text-sm text-gray-500 truncate'>
|
||||
{!status.wechat_login
|
||||
? t('未启用')
|
||||
: isBound(userState.user?.wechat_id)
|
||||
? t('已绑定')
|
||||
: t('未绑定')}
|
||||
{userState.user && userState.user.wechat_id !== ''
|
||||
? t('已绑定')
|
||||
: t('未绑定')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -185,7 +179,7 @@ const AccountManagement = ({
|
||||
disabled={!status.wechat_login}
|
||||
onClick={() => setShowWeChatBindModal(true)}
|
||||
>
|
||||
{isBound(userState.user?.wechat_id)
|
||||
{userState.user && userState.user.wechat_id !== ''
|
||||
? t('修改绑定')
|
||||
: status.wechat_login
|
||||
? t('绑定')
|
||||
@@ -226,7 +220,8 @@ const AccountManagement = ({
|
||||
onGitHubOAuthClicked(status.github_client_id)
|
||||
}
|
||||
disabled={
|
||||
isBound(userState.user?.github_id) || !status.github_oauth
|
||||
(userState.user && userState.user.github_id !== '') ||
|
||||
!status.github_oauth
|
||||
}
|
||||
>
|
||||
{status.github_oauth ? t('绑定') : t('未启用')}
|
||||
@@ -269,7 +264,8 @@ const AccountManagement = ({
|
||||
)
|
||||
}
|
||||
disabled={
|
||||
isBound(userState.user?.oidc_id) || !status.oidc_enabled
|
||||
(userState.user && userState.user.oidc_id !== '') ||
|
||||
!status.oidc_enabled
|
||||
}
|
||||
>
|
||||
{status.oidc_enabled ? t('绑定') : t('未启用')}
|
||||
@@ -302,56 +298,26 @@ const AccountManagement = ({
|
||||
</div>
|
||||
<div className='flex-shrink-0'>
|
||||
{status.telegram_oauth ? (
|
||||
isBound(userState.user?.telegram_id) ? (
|
||||
<Button
|
||||
disabled
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
>
|
||||
userState.user.telegram_id !== '' ? (
|
||||
<Button disabled={true} size='small'>
|
||||
{t('已绑定')}
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
type='primary'
|
||||
theme='outline'
|
||||
size='small'
|
||||
onClick={() => setShowTelegramBindModal(true)}
|
||||
>
|
||||
{t('绑定')}
|
||||
</Button>
|
||||
<div className='scale-75'>
|
||||
<TelegramLoginButton
|
||||
dataAuthUrl='/api/oauth/telegram/bind'
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
) : (
|
||||
<Button
|
||||
disabled
|
||||
size='small'
|
||||
type='primary'
|
||||
theme='outline'
|
||||
>
|
||||
<Button disabled={true} size='small'>
|
||||
{t('未启用')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
<Modal
|
||||
title={t('绑定 Telegram')}
|
||||
visible={showTelegramBindModal}
|
||||
onCancel={() => setShowTelegramBindModal(false)}
|
||||
footer={null}
|
||||
>
|
||||
<div className='my-3 text-sm text-gray-600'>
|
||||
{t('点击下方按钮通过 Telegram 完成绑定')}
|
||||
</div>
|
||||
<div className='flex justify-center'>
|
||||
<div className='scale-90'>
|
||||
<TelegramLoginButton
|
||||
dataAuthUrl='/api/oauth/telegram/bind'
|
||||
botName={status.telegram_bot_name}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
|
||||
{/* LinuxDO绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
@@ -384,7 +350,8 @@ const AccountManagement = ({
|
||||
onLinuxDOOAuthClicked(status.linuxdo_client_id)
|
||||
}
|
||||
disabled={
|
||||
isBound(userState.user?.linux_do_id) || !status.linuxdo_oauth
|
||||
(userState.user && userState.user.linux_do_id !== '') ||
|
||||
!status.linuxdo_oauth
|
||||
}
|
||||
>
|
||||
{status.linuxdo_oauth ? t('绑定') : t('未启用')}
|
||||
|
||||
@@ -40,11 +40,10 @@ import {
|
||||
showSuccess,
|
||||
showError,
|
||||
} from '../../../../helpers';
|
||||
import CodeViewer from '../../../playground/CodeViewer';
|
||||
import CodeViewer from '../../../common/ui/CodeViewer';
|
||||
import { StatusContext } from '../../../../context/Status';
|
||||
import { UserContext } from '../../../../context/User';
|
||||
import { useUserPermissions } from '../../../../hooks/common/useUserPermissions';
|
||||
import { useSidebar } from '../../../../hooks/common/useSidebar';
|
||||
|
||||
const NotificationSettings = ({
|
||||
t,
|
||||
@@ -98,9 +97,6 @@ const NotificationSettings = ({
|
||||
isSidebarModuleAllowed,
|
||||
} = useUserPermissions();
|
||||
|
||||
// 使用useSidebar钩子获取刷新方法
|
||||
const { refreshUserConfig } = useSidebar();
|
||||
|
||||
// 左侧边栏设置处理函数
|
||||
const handleSectionChange = (sectionKey) => {
|
||||
return (checked) => {
|
||||
@@ -136,9 +132,6 @@ const NotificationSettings = ({
|
||||
});
|
||||
if (res.data.success) {
|
||||
showSuccess(t('侧边栏设置保存成功'));
|
||||
|
||||
// 刷新useSidebar钩子中的用户配置,实现实时更新
|
||||
await refreshUserConfig();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
@@ -341,7 +334,7 @@ const NotificationSettings = ({
|
||||
loading={sidebarLoading}
|
||||
className='!rounded-lg'
|
||||
>
|
||||
{t('保存设置')}
|
||||
{t('保存边栏设置')}
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
|
||||
@@ -85,26 +85,6 @@ const REGION_EXAMPLE = {
|
||||
'claude-3-5-sonnet-20240620': 'europe-west1',
|
||||
};
|
||||
|
||||
// 支持并且已适配通过接口获取模型列表的渠道类型
|
||||
const MODEL_FETCHABLE_TYPES = new Set([
|
||||
1,
|
||||
4,
|
||||
14,
|
||||
34,
|
||||
17,
|
||||
26,
|
||||
24,
|
||||
47,
|
||||
25,
|
||||
20,
|
||||
23,
|
||||
31,
|
||||
35,
|
||||
40,
|
||||
42,
|
||||
48,
|
||||
]);
|
||||
|
||||
function type2secretPrompt(type) {
|
||||
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
||||
switch (type) {
|
||||
@@ -164,8 +144,6 @@ const EditChannelModal = (props) => {
|
||||
settings: '',
|
||||
// 仅 Vertex: 密钥格式(存入 settings.vertex_key_type)
|
||||
vertex_key_type: 'json',
|
||||
// 企业账户设置
|
||||
is_enterprise_account: false,
|
||||
};
|
||||
const [batch, setBatch] = useState(false);
|
||||
const [multiToSingle, setMultiToSingle] = useState(false);
|
||||
@@ -191,7 +169,6 @@ const EditChannelModal = (props) => {
|
||||
const [channelSearchValue, setChannelSearchValue] = useState('');
|
||||
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
|
||||
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
|
||||
|
||||
// 2FA验证查看密钥相关状态
|
||||
const [twoFAState, setTwoFAState] = useState({
|
||||
@@ -238,7 +215,7 @@ const EditChannelModal = (props) => {
|
||||
pass_through_body_enabled: false,
|
||||
system_prompt: '',
|
||||
});
|
||||
const showApiConfigCard = true; // 控制是否显示 API 配置卡片
|
||||
const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示)
|
||||
const getInitValues = () => ({ ...originInputs });
|
||||
|
||||
// 处理渠道额外设置的更新
|
||||
@@ -345,10 +322,6 @@ const EditChannelModal = (props) => {
|
||||
case 36:
|
||||
localModels = ['suno_music', 'suno_lyrics'];
|
||||
break;
|
||||
case 45:
|
||||
localModels = getChannelModels(value);
|
||||
setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' }));
|
||||
break;
|
||||
default:
|
||||
localModels = getChannelModels(value);
|
||||
break;
|
||||
@@ -440,27 +413,15 @@ const EditChannelModal = (props) => {
|
||||
parsedSettings.azure_responses_version || '';
|
||||
// 读取 Vertex 密钥格式
|
||||
data.vertex_key_type = parsedSettings.vertex_key_type || 'json';
|
||||
// 读取企业账户设置
|
||||
data.is_enterprise_account = parsedSettings.openrouter_enterprise === true;
|
||||
} catch (error) {
|
||||
console.error('解析其他设置失败:', error);
|
||||
data.azure_responses_version = '';
|
||||
data.region = '';
|
||||
data.vertex_key_type = 'json';
|
||||
data.is_enterprise_account = false;
|
||||
}
|
||||
} else {
|
||||
// 兼容历史数据:老渠道没有 settings 时,默认按 json 展示
|
||||
data.vertex_key_type = 'json';
|
||||
data.is_enterprise_account = false;
|
||||
}
|
||||
|
||||
if (
|
||||
data.type === 45 &&
|
||||
(!data.base_url ||
|
||||
(typeof data.base_url === 'string' && data.base_url.trim() === ''))
|
||||
) {
|
||||
data.base_url = 'https://ark.cn-beijing.volces.com';
|
||||
}
|
||||
|
||||
setInputs(data);
|
||||
@@ -472,8 +433,6 @@ const EditChannelModal = (props) => {
|
||||
} else {
|
||||
setAutoBan(true);
|
||||
}
|
||||
// 同步企业账户状态
|
||||
setIsEnterpriseAccount(data.is_enterprise_account || false);
|
||||
setBasicModels(getChannelModels(data.type));
|
||||
// 同步更新channelSettings状态显示
|
||||
setChannelSettings({
|
||||
@@ -733,8 +692,6 @@ const EditChannelModal = (props) => {
|
||||
});
|
||||
// 重置密钥模式状态
|
||||
setKeyMode('append');
|
||||
// 重置企业账户状态
|
||||
setIsEnterpriseAccount(false);
|
||||
// 清空表单中的key_mode字段
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key_mode', undefined);
|
||||
@@ -867,10 +824,6 @@ const EditChannelModal = (props) => {
|
||||
showInfo(t('请至少选择一个模型!'));
|
||||
return;
|
||||
}
|
||||
if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) {
|
||||
showInfo(t('请输入API地址!'));
|
||||
return;
|
||||
}
|
||||
if (
|
||||
localInputs.model_mapping &&
|
||||
localInputs.model_mapping !== '' &&
|
||||
@@ -900,21 +853,6 @@ const EditChannelModal = (props) => {
|
||||
};
|
||||
localInputs.setting = JSON.stringify(channelExtraSettings);
|
||||
|
||||
// 处理type === 20的企业账户设置
|
||||
if (localInputs.type === 20) {
|
||||
let settings = {};
|
||||
if (localInputs.settings) {
|
||||
try {
|
||||
settings = JSON.parse(localInputs.settings);
|
||||
} catch (error) {
|
||||
console.error('解析settings失败:', error);
|
||||
}
|
||||
}
|
||||
// 设置企业账户标识,无论是true还是false都要传到后端
|
||||
settings.openrouter_enterprise = localInputs.is_enterprise_account === true;
|
||||
localInputs.settings = JSON.stringify(settings);
|
||||
}
|
||||
|
||||
// 清理不需要发送到后端的字段
|
||||
delete localInputs.force_format;
|
||||
delete localInputs.thinking_to_content;
|
||||
@@ -922,7 +860,6 @@ const EditChannelModal = (props) => {
|
||||
delete localInputs.pass_through_body_enabled;
|
||||
delete localInputs.system_prompt;
|
||||
delete localInputs.system_prompt_override;
|
||||
delete localInputs.is_enterprise_account;
|
||||
// 顶层的 vertex_key_type 不应发送给后端
|
||||
delete localInputs.vertex_key_type;
|
||||
|
||||
@@ -964,56 +901,6 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
};
|
||||
|
||||
// 密钥去重函数
|
||||
const deduplicateKeys = () => {
|
||||
const currentKey = formApiRef.current?.getValue('key') || inputs.key || '';
|
||||
|
||||
if (!currentKey.trim()) {
|
||||
showInfo(t('请先输入密钥'));
|
||||
return;
|
||||
}
|
||||
|
||||
// 按行分割密钥
|
||||
const keyLines = currentKey.split('\n');
|
||||
const beforeCount = keyLines.length;
|
||||
|
||||
// 使用哈希表去重,保持原有顺序
|
||||
const keySet = new Set();
|
||||
const deduplicatedKeys = [];
|
||||
|
||||
keyLines.forEach((line) => {
|
||||
const trimmedLine = line.trim();
|
||||
if (trimmedLine && !keySet.has(trimmedLine)) {
|
||||
keySet.add(trimmedLine);
|
||||
deduplicatedKeys.push(trimmedLine);
|
||||
}
|
||||
});
|
||||
|
||||
const afterCount = deduplicatedKeys.length;
|
||||
const deduplicatedKeyText = deduplicatedKeys.join('\n');
|
||||
|
||||
// 更新表单和状态
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key', deduplicatedKeyText);
|
||||
}
|
||||
handleInputChange('key', deduplicatedKeyText);
|
||||
|
||||
// 显示去重结果
|
||||
const message = t(
|
||||
'去重完成:去重前 {{before}} 个密钥,去重后 {{after}} 个密钥',
|
||||
{
|
||||
before: beforeCount,
|
||||
after: afterCount,
|
||||
},
|
||||
);
|
||||
|
||||
if (beforeCount === afterCount) {
|
||||
showInfo(t('未发现重复密钥'));
|
||||
} else {
|
||||
showSuccess(message);
|
||||
}
|
||||
};
|
||||
|
||||
const addCustomModels = () => {
|
||||
if (customModel.trim() === '') return;
|
||||
const modelArray = customModel.split(',').map((model) => model.trim());
|
||||
@@ -1109,41 +996,24 @@ const EditChannelModal = (props) => {
|
||||
</Checkbox>
|
||||
)}
|
||||
{batch && (
|
||||
<>
|
||||
<Checkbox
|
||||
disabled={isEdit}
|
||||
checked={multiToSingle}
|
||||
onChange={() => {
|
||||
setMultiToSingle((prev) => {
|
||||
const nextValue = !prev;
|
||||
setInputs((prevInputs) => {
|
||||
const newInputs = { ...prevInputs };
|
||||
if (nextValue) {
|
||||
newInputs.multi_key_mode = multiKeyMode;
|
||||
} else {
|
||||
delete newInputs.multi_key_mode;
|
||||
}
|
||||
return newInputs;
|
||||
});
|
||||
return nextValue;
|
||||
});
|
||||
}}
|
||||
>
|
||||
{t('密钥聚合模式')}
|
||||
</Checkbox>
|
||||
|
||||
{inputs.type !== 41 && (
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
theme='outline'
|
||||
onClick={deduplicateKeys}
|
||||
style={{ textDecoration: 'underline' }}
|
||||
>
|
||||
{t('密钥去重')}
|
||||
</Button>
|
||||
)}
|
||||
</>
|
||||
<Checkbox
|
||||
disabled={isEdit}
|
||||
checked={multiToSingle}
|
||||
onChange={() => {
|
||||
setMultiToSingle((prev) => !prev);
|
||||
setInputs((prev) => {
|
||||
const newInputs = { ...prev };
|
||||
if (!multiToSingle) {
|
||||
newInputs.multi_key_mode = multiKeyMode;
|
||||
} else {
|
||||
delete newInputs.multi_key_mode;
|
||||
}
|
||||
return newInputs;
|
||||
});
|
||||
}}
|
||||
>
|
||||
{t('密钥聚合模式')}
|
||||
</Checkbox>
|
||||
)}
|
||||
</Space>
|
||||
) : null;
|
||||
@@ -1307,21 +1177,6 @@ const EditChannelModal = (props) => {
|
||||
onChange={(value) => handleInputChange('type', value)}
|
||||
/>
|
||||
|
||||
{inputs.type === 20 && (
|
||||
<Form.Switch
|
||||
field='is_enterprise_account'
|
||||
label={t('是否为企业账户')}
|
||||
checkedText={t('是')}
|
||||
uncheckedText={t('否')}
|
||||
onChange={(value) => {
|
||||
setIsEnterpriseAccount(value);
|
||||
handleInputChange('is_enterprise_account', value);
|
||||
}}
|
||||
extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')}
|
||||
initValue={inputs.is_enterprise_account}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Form.Input
|
||||
field='name'
|
||||
label={t('名称')}
|
||||
@@ -1405,7 +1260,7 @@ const EditChannelModal = (props) => {
|
||||
autoComplete='new-password'
|
||||
onChange={(value) => handleInputChange('key', value)}
|
||||
extraText={
|
||||
<div className='flex items-center gap-2 flex-wrap'>
|
||||
<div className='flex items-center gap-2'>
|
||||
{isEdit &&
|
||||
isMultiKeyChannel &&
|
||||
keyMode === 'append' && (
|
||||
@@ -1941,30 +1796,6 @@ const EditChannelModal = (props) => {
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{inputs.type === 45 && (
|
||||
<div>
|
||||
<Form.Select
|
||||
field='base_url'
|
||||
label={t('API地址')}
|
||||
placeholder={t('请选择API地址')}
|
||||
onChange={(value) =>
|
||||
handleInputChange('base_url', value)
|
||||
}
|
||||
optionList={[
|
||||
{
|
||||
value: 'https://ark.cn-beijing.volces.com',
|
||||
label: 'https://ark.cn-beijing.volces.com'
|
||||
},
|
||||
{
|
||||
value: 'https://ark.ap-southeast.bytepluses.com',
|
||||
label: 'https://ark.ap-southeast.bytepluses.com'
|
||||
}
|
||||
]}
|
||||
defaultValue='https://ark.cn-beijing.volces.com'
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
)}
|
||||
|
||||
@@ -2048,15 +1879,13 @@ const EditChannelModal = (props) => {
|
||||
>
|
||||
{t('填入所有模型')}
|
||||
</Button>
|
||||
{MODEL_FETCHABLE_TYPES.has(inputs.type) && (
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
onClick={() => fetchUpstreamModelList('models')}
|
||||
>
|
||||
{t('获取模型列表')}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
size='small'
|
||||
type='tertiary'
|
||||
onClick={() => fetchUpstreamModelList('models')}
|
||||
>
|
||||
{t('获取模型列表')}
|
||||
</Button>
|
||||
<Button
|
||||
size='small'
|
||||
type='warning'
|
||||
|
||||
@@ -247,32 +247,6 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
|
||||
}
|
||||
};
|
||||
|
||||
// Delete a specific key
|
||||
const handleDeleteKey = async (keyIndex) => {
|
||||
const operationId = `delete_${keyIndex}`;
|
||||
setOperationLoading((prev) => ({ ...prev, [operationId]: true }));
|
||||
|
||||
try {
|
||||
const res = await API.post('/api/channel/multi_key/manage', {
|
||||
channel_id: channel.id,
|
||||
action: 'delete_key',
|
||||
key_index: keyIndex,
|
||||
});
|
||||
|
||||
if (res.data.success) {
|
||||
showSuccess(t('密钥已删除'));
|
||||
await loadKeyStatus(currentPage, pageSize); // Reload current page
|
||||
onRefresh && onRefresh(); // Refresh parent component
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('删除密钥失败'));
|
||||
} finally {
|
||||
setOperationLoading((prev) => ({ ...prev, [operationId]: false }));
|
||||
}
|
||||
};
|
||||
|
||||
// Handle page change
|
||||
const handlePageChange = (page) => {
|
||||
setCurrentPage(page);
|
||||
@@ -410,7 +384,7 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
|
||||
title: t('操作'),
|
||||
key: 'action',
|
||||
fixed: 'right',
|
||||
width: 150,
|
||||
width: 100,
|
||||
render: (_, record) => (
|
||||
<Space>
|
||||
{record.status === 1 ? (
|
||||
@@ -432,21 +406,6 @@ const MultiKeyManageModal = ({ visible, onCancel, channel, onRefresh }) => {
|
||||
{t('启用')}
|
||||
</Button>
|
||||
)}
|
||||
<Popconfirm
|
||||
title={t('确定要删除此密钥吗?')}
|
||||
content={t('此操作不可撤销,将永久删除该密钥')}
|
||||
onConfirm={() => handleDeleteKey(record.index)}
|
||||
okType={'danger'}
|
||||
position={'topRight'}
|
||||
>
|
||||
<Button
|
||||
type='danger'
|
||||
size='small'
|
||||
loading={operationLoading[`delete_${record.index}`]}
|
||||
>
|
||||
{t('删除')}
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user