mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-30 12:21:45 +00:00
184 lines
5.5 KiB
Go
184 lines
5.5 KiB
Go
package model
|
|
|
|
import (
|
|
"encoding/json"
|
|
"one-api/common"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// OAuthClient is the GORM model for a registered OAuth2 client.
//
// List-valued settings are persisted as strings: RedirectURIs holds a JSON
// array, while GrantTypes and Scopes hold comma-separated values (see the
// Get*/Set* accessor methods). DeletedAt is a gorm.DeletedAt, so GORM
// soft-deletes rows of this model.
//
// NOTE(review): Secret carries the json tag "secret", so it is included every
// time this struct is JSON-encoded — confirm that list/read endpoints strip it.
// NOTE(review): CreatedTime and LastUsedTime use a bare `bigint` gorm setting
// without the `type:` key; GORM appears to ignore unknown bare settings and
// derive the column type from the Go type instead. `type:bigint` was probably
// intended — confirm before changing, as it may alter AutoMigrate behavior.
type OAuthClient struct {
	ID           string         `json:"id" gorm:"type:varchar(64);primaryKey"`
	Secret       string         `json:"secret" gorm:"type:varchar(128);not null"`
	Name         string         `json:"name" gorm:"type:varchar(255);not null"`
	Domain       string         `json:"domain" gorm:"type:varchar(255)"` // allowed redirect domain
	RedirectURIs string         `json:"redirect_uris" gorm:"type:text"`  // JSON array of redirect URIs
	GrantTypes   string         `json:"grant_types" gorm:"type:varchar(255);default:'client_credentials'"`
	Scopes       string         `json:"scopes" gorm:"type:varchar(255);default:'api:read'"`
	RequirePKCE  bool           `json:"require_pkce" gorm:"default:true"`
	Status       int            `json:"status" gorm:"type:int;default:1"`        // 1: enabled, 2: disabled
	CreatedBy    int            `json:"created_by" gorm:"type:int;not null"` // ID of the user who created the client
	CreatedTime  int64          `json:"created_time" gorm:"bigint"`
	LastUsedTime int64          `json:"last_used_time" gorm:"bigint;default:0"`
	TokenCount   int            `json:"token_count" gorm:"type:int;default:0"` // number of tokens issued so far
	Description  string         `json:"description" gorm:"type:text"`
	ClientType   string         `json:"client_type" gorm:"type:varchar(32);default:'confidential'"` // confidential, public
	DeletedAt    gorm.DeletedAt `gorm:"index"`
}
|
|
|
|
// GetRedirectURIs 获取重定向URI列表
|
|
func (c *OAuthClient) GetRedirectURIs() []string {
|
|
if c.RedirectURIs == "" {
|
|
return []string{}
|
|
}
|
|
var uris []string
|
|
err := json.Unmarshal([]byte(c.RedirectURIs), &uris)
|
|
if err != nil {
|
|
common.SysLog("failed to unmarshal redirect URIs: " + err.Error())
|
|
return []string{}
|
|
}
|
|
return uris
|
|
}
|
|
|
|
// SetRedirectURIs 设置重定向URI列表
|
|
func (c *OAuthClient) SetRedirectURIs(uris []string) {
|
|
data, err := json.Marshal(uris)
|
|
if err != nil {
|
|
common.SysLog("failed to marshal redirect URIs: " + err.Error())
|
|
return
|
|
}
|
|
c.RedirectURIs = string(data)
|
|
}
|
|
|
|
// GetGrantTypes 获取允许的授权类型列表
|
|
func (c *OAuthClient) GetGrantTypes() []string {
|
|
if c.GrantTypes == "" {
|
|
return []string{"client_credentials"}
|
|
}
|
|
return strings.Split(c.GrantTypes, ",")
|
|
}
|
|
|
|
// SetGrantTypes 设置允许的授权类型列表
|
|
func (c *OAuthClient) SetGrantTypes(types []string) {
|
|
c.GrantTypes = strings.Join(types, ",")
|
|
}
|
|
|
|
// GetScopes 获取允许的scope列表
|
|
func (c *OAuthClient) GetScopes() []string {
|
|
if c.Scopes == "" {
|
|
return []string{"api:read"}
|
|
}
|
|
return strings.Split(c.Scopes, ",")
|
|
}
|
|
|
|
// SetScopes 设置允许的scope列表
|
|
func (c *OAuthClient) SetScopes(scopes []string) {
|
|
c.Scopes = strings.Join(scopes, ",")
|
|
}
|
|
|
|
// ValidateRedirectURI 验证重定向URI是否有效
|
|
func (c *OAuthClient) ValidateRedirectURI(uri string) bool {
|
|
allowedURIs := c.GetRedirectURIs()
|
|
for _, allowedURI := range allowedURIs {
|
|
if allowedURI == uri {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// ValidateGrantType 验证授权类型是否被允许
|
|
func (c *OAuthClient) ValidateGrantType(grantType string) bool {
|
|
allowedTypes := c.GetGrantTypes()
|
|
for _, allowedType := range allowedTypes {
|
|
if allowedType == grantType {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// ValidateScope 验证scope是否被允许
|
|
func (c *OAuthClient) ValidateScope(scope string) bool {
|
|
allowedScopes := c.GetScopes()
|
|
requestedScopes := strings.Split(scope, " ")
|
|
|
|
for _, requestedScope := range requestedScopes {
|
|
requestedScope = strings.TrimSpace(requestedScope)
|
|
if requestedScope == "" {
|
|
continue
|
|
}
|
|
found := false
|
|
for _, allowedScope := range allowedScopes {
|
|
if allowedScope == requestedScope {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// BeforeCreate GORM hook - 在创建前设置时间
|
|
func (c *OAuthClient) BeforeCreate(tx *gorm.DB) (err error) {
|
|
c.CreatedTime = time.Now().Unix()
|
|
return
|
|
}
|
|
|
|
// UpdateLastUsedTime 更新最后使用时间
|
|
func (c *OAuthClient) UpdateLastUsedTime() error {
|
|
c.LastUsedTime = time.Now().Unix()
|
|
c.TokenCount++
|
|
return DB.Model(c).Select("last_used_time", "token_count").Updates(c).Error
|
|
}
|
|
|
|
// GetOAuthClientByID 根据ID获取OAuth客户端
|
|
func GetOAuthClientByID(id string) (*OAuthClient, error) {
|
|
var client OAuthClient
|
|
err := DB.Where("id = ? AND status = ?", id, common.UserStatusEnabled).First(&client).Error
|
|
return &client, err
|
|
}
|
|
|
|
// GetAllOAuthClients 获取所有OAuth客户端
|
|
func GetAllOAuthClients(startIdx int, num int) ([]*OAuthClient, error) {
|
|
var clients []*OAuthClient
|
|
err := DB.Order("created_time desc").Limit(num).Offset(startIdx).Find(&clients).Error
|
|
return clients, err
|
|
}
|
|
|
|
// SearchOAuthClients 搜索OAuth客户端
|
|
func SearchOAuthClients(keyword string) ([]*OAuthClient, error) {
|
|
var clients []*OAuthClient
|
|
err := DB.Where("name LIKE ? OR id LIKE ? OR description LIKE ?",
|
|
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%").Find(&clients).Error
|
|
return clients, err
|
|
}
|
|
|
|
// CreateOAuthClient 创建OAuth客户端
|
|
func CreateOAuthClient(client *OAuthClient) error {
|
|
return DB.Create(client).Error
|
|
}
|
|
|
|
// UpdateOAuthClient 更新OAuth客户端
|
|
func UpdateOAuthClient(client *OAuthClient) error {
|
|
return DB.Save(client).Error
|
|
}
|
|
|
|
// DeleteOAuthClient 删除OAuth客户端
|
|
func DeleteOAuthClient(id string) error {
|
|
return DB.Where("id = ?", id).Delete(&OAuthClient{}).Error
|
|
}
|
|
|
|
// CountOAuthClients 统计OAuth客户端数量
|
|
func CountOAuthClients() (int64, error) {
|
|
var count int64
|
|
err := DB.Model(&OAuthClient{}).Count(&count).Error
|
|
return count, err
|
|
}
|