mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 08:36:22 +00:00
211 lines
7.8 KiB
Go
211 lines
7.8 KiB
Go
package model
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
|
"github.com/go-webauthn/webauthn/protocol"
|
|
"github.com/go-webauthn/webauthn/webauthn"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// ErrPasskeyNotFound is returned by GetPasskeyByUserID when the user has no
// stored passkey. That is a normal state (the user simply never bound one),
// so callers can distinguish it from real failures.
var ErrPasskeyNotFound = errors.New("passkey credential not found")

// ErrFriendlyPasskeyNotFound is the user-facing error returned for every other
// lookup failure: invalid input, an unknown credential ID, or a database error.
var ErrFriendlyPasskeyNotFound = errors.New("Passkey 验证失败,请重试或联系管理员")
|
|
|
|
// PasskeyCredential is the database (GORM) representation of one WebAuthn
// passkey. Binary values coming from the webauthn library are stored as
// standard-encoding base64 strings, and transports as a JSON array of strings;
// ToWebAuthnCredential, NewPasskeyCredentialFromWebAuthn and
// ApplyValidatedCredential convert between the two representations.
//
// The unique index on user_id allows at most one row per user. A database
// unique index also covers soft-deleted rows, which is why the delete paths in
// this file use Unscoped (hard) deletes.
type PasskeyCredential struct {
|
|
ID int `json:"id" gorm:"primaryKey"`
|
|
// UserID is the owning user; uniqueIndex limits each user to one row.
UserID int `json:"user_id" gorm:"uniqueIndex;not null"`
|
|
// CredentialID is the WebAuthn credential ID in base64 (StdEncoding).
// NOTE(review): base64 is case-sensitive; confirm the column collation is too
// (some MySQL defaults compare case-insensitively), otherwise the unique index
// and the equality lookup could conflate IDs that differ only in case.
CredentialID string `json:"credential_id" gorm:"type:varchar(512);uniqueIndex;not null"` // base64 encoded
|
|
// PublicKey is webauthn.Credential.PublicKey in base64 (StdEncoding).
PublicKey string `json:"public_key" gorm:"type:text;not null"` // base64 encoded
|
|
// AttestationType is copied verbatim from webauthn.Credential.AttestationType.
AttestationType string `json:"attestation_type" gorm:"type:varchar(255)"`
|
|
// AAGUID is webauthn.Authenticator.AAGUID in base64 (StdEncoding).
AAGUID string `json:"aaguid" gorm:"type:varchar(512)"` // base64 encoded
|
|
// SignCount mirrors webauthn.Authenticator.SignCount.
SignCount uint32 `json:"sign_count" gorm:"default:0"`
|
|
// CloneWarning mirrors webauthn.Authenticator.CloneWarning.
CloneWarning bool `json:"clone_warning"`
|
|
// UserPresent, UserVerified, BackupEligible and BackupState mirror
// webauthn.CredentialFlags.
UserPresent bool `json:"user_present"`
|
|
UserVerified bool `json:"user_verified"`
|
|
BackupEligible bool `json:"backup_eligible"`
|
|
BackupState bool `json:"backup_state"`
|
|
// Transports is a JSON array of transport names, or "" when there are none
// (written by SetTransports, read by TransportList).
Transports string `json:"transports" gorm:"type:text"`
|
|
// Attachment is webauthn.Authenticator.Attachment stored as a plain string.
Attachment string `json:"attachment" gorm:"type:varchar(32)"`
|
|
// LastUsedAt is nullable and is not written by any function in this file.
// NOTE(review): presumably maintained by the login handlers; confirm.
LastUsedAt *time.Time `json:"last_used_at"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
UpdatedAt time.Time `json:"updated_at"`
|
|
// DeletedAt is the GORM soft-delete marker, hidden from JSON; the deletes in
// this file bypass it with Unscoped.
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
|
}
|
|
|
|
func (p *PasskeyCredential) TransportList() []protocol.AuthenticatorTransport {
|
|
if p == nil || strings.TrimSpace(p.Transports) == "" {
|
|
return nil
|
|
}
|
|
var transports []string
|
|
if err := json.Unmarshal([]byte(p.Transports), &transports); err != nil {
|
|
return nil
|
|
}
|
|
result := make([]protocol.AuthenticatorTransport, 0, len(transports))
|
|
for _, transport := range transports {
|
|
result = append(result, protocol.AuthenticatorTransport(transport))
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (p *PasskeyCredential) SetTransports(list []protocol.AuthenticatorTransport) {
|
|
if len(list) == 0 {
|
|
p.Transports = ""
|
|
return
|
|
}
|
|
stringList := make([]string, len(list))
|
|
for i, transport := range list {
|
|
stringList[i] = string(transport)
|
|
}
|
|
encoded, err := json.Marshal(stringList)
|
|
if err != nil {
|
|
return
|
|
}
|
|
p.Transports = string(encoded)
|
|
}
|
|
|
|
func (p *PasskeyCredential) ToWebAuthnCredential() webauthn.Credential {
|
|
flags := webauthn.CredentialFlags{
|
|
UserPresent: p.UserPresent,
|
|
UserVerified: p.UserVerified,
|
|
BackupEligible: p.BackupEligible,
|
|
BackupState: p.BackupState,
|
|
}
|
|
|
|
credID, _ := base64.StdEncoding.DecodeString(p.CredentialID)
|
|
pubKey, _ := base64.StdEncoding.DecodeString(p.PublicKey)
|
|
aaguid, _ := base64.StdEncoding.DecodeString(p.AAGUID)
|
|
|
|
return webauthn.Credential{
|
|
ID: credID,
|
|
PublicKey: pubKey,
|
|
AttestationType: p.AttestationType,
|
|
Transport: p.TransportList(),
|
|
Flags: flags,
|
|
Authenticator: webauthn.Authenticator{
|
|
AAGUID: aaguid,
|
|
SignCount: p.SignCount,
|
|
CloneWarning: p.CloneWarning,
|
|
Attachment: protocol.AuthenticatorAttachment(p.Attachment),
|
|
},
|
|
}
|
|
}
|
|
|
|
func NewPasskeyCredentialFromWebAuthn(userID int, credential *webauthn.Credential) *PasskeyCredential {
|
|
if credential == nil {
|
|
return nil
|
|
}
|
|
passkey := &PasskeyCredential{
|
|
UserID: userID,
|
|
CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
|
|
PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
|
|
AttestationType: credential.AttestationType,
|
|
AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
|
|
SignCount: credential.Authenticator.SignCount,
|
|
CloneWarning: credential.Authenticator.CloneWarning,
|
|
UserPresent: credential.Flags.UserPresent,
|
|
UserVerified: credential.Flags.UserVerified,
|
|
BackupEligible: credential.Flags.BackupEligible,
|
|
BackupState: credential.Flags.BackupState,
|
|
Attachment: string(credential.Authenticator.Attachment),
|
|
}
|
|
passkey.SetTransports(credential.Transport)
|
|
return passkey
|
|
}
|
|
|
|
func (p *PasskeyCredential) ApplyValidatedCredential(credential *webauthn.Credential) {
|
|
if credential == nil || p == nil {
|
|
return
|
|
}
|
|
p.CredentialID = base64.StdEncoding.EncodeToString(credential.ID)
|
|
p.PublicKey = base64.StdEncoding.EncodeToString(credential.PublicKey)
|
|
p.AttestationType = credential.AttestationType
|
|
p.AAGUID = base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID)
|
|
p.SignCount = credential.Authenticator.SignCount
|
|
p.CloneWarning = credential.Authenticator.CloneWarning
|
|
p.UserPresent = credential.Flags.UserPresent
|
|
p.UserVerified = credential.Flags.UserVerified
|
|
p.BackupEligible = credential.Flags.BackupEligible
|
|
p.BackupState = credential.Flags.BackupState
|
|
p.Attachment = string(credential.Authenticator.Attachment)
|
|
p.SetTransports(credential.Transport)
|
|
}
|
|
|
|
func GetPasskeyByUserID(userID int) (*PasskeyCredential, error) {
|
|
if userID == 0 {
|
|
common.SysLog("GetPasskeyByUserID: empty user ID")
|
|
return nil, ErrFriendlyPasskeyNotFound
|
|
}
|
|
var credential PasskeyCredential
|
|
if err := DB.Where("user_id = ?", userID).First(&credential).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
// 未找到记录是正常情况(用户未绑定),返回 ErrPasskeyNotFound 而不记录日志
|
|
return nil, ErrPasskeyNotFound
|
|
}
|
|
// 只有真正的数据库错误才记录日志
|
|
common.SysLog(fmt.Sprintf("GetPasskeyByUserID: database error for user %d: %v", userID, err))
|
|
return nil, ErrFriendlyPasskeyNotFound
|
|
}
|
|
return &credential, nil
|
|
}
|
|
|
|
func GetPasskeyByCredentialID(credentialID []byte) (*PasskeyCredential, error) {
|
|
if len(credentialID) == 0 {
|
|
common.SysLog("GetPasskeyByCredentialID: empty credential ID")
|
|
return nil, ErrFriendlyPasskeyNotFound
|
|
}
|
|
|
|
credIDStr := base64.StdEncoding.EncodeToString(credentialID)
|
|
var credential PasskeyCredential
|
|
if err := DB.Where("credential_id = ?", credIDStr).First(&credential).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: passkey not found for credential ID length %d", len(credentialID)))
|
|
return nil, ErrFriendlyPasskeyNotFound
|
|
}
|
|
common.SysLog(fmt.Sprintf("GetPasskeyByCredentialID: database error for credential ID: %v", err))
|
|
return nil, ErrFriendlyPasskeyNotFound
|
|
}
|
|
|
|
return &credential, nil
|
|
}
|
|
|
|
func UpsertPasskeyCredential(credential *PasskeyCredential) error {
|
|
if credential == nil {
|
|
common.SysLog("UpsertPasskeyCredential: nil credential provided")
|
|
return fmt.Errorf("Passkey 保存失败,请重试")
|
|
}
|
|
return DB.Transaction(func(tx *gorm.DB) error {
|
|
// 使用Unscoped()进行硬删除,避免唯一索引冲突
|
|
if err := tx.Unscoped().Where("user_id = ?", credential.UserID).Delete(&PasskeyCredential{}).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to delete existing credential for user %d: %v", credential.UserID, err))
|
|
return fmt.Errorf("Passkey 保存失败,请重试")
|
|
}
|
|
if err := tx.Create(credential).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("UpsertPasskeyCredential: failed to create credential for user %d: %v", credential.UserID, err))
|
|
return fmt.Errorf("Passkey 保存失败,请重试")
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func DeletePasskeyByUserID(userID int) error {
|
|
if userID == 0 {
|
|
common.SysLog("DeletePasskeyByUserID: empty user ID")
|
|
return fmt.Errorf("删除失败,请重试")
|
|
}
|
|
// 使用Unscoped()进行硬删除,避免唯一索引冲突
|
|
if err := DB.Unscoped().Where("user_id = ?", userID).Delete(&PasskeyCredential{}).Error; err != nil {
|
|
common.SysLog(fmt.Sprintf("DeletePasskeyByUserID: failed to delete passkey for user %d: %v", userID, err))
|
|
return fmt.Errorf("删除失败,请重试")
|
|
}
|
|
return nil
|
|
}
|