mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-18 18:57:27 +00:00
137 lines
2.8 KiB
Go
137 lines
2.8 KiB
Go
package hooks
|
||
|
||
import (
|
||
"fmt"
|
||
"sync"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/core/interfaces"
|
||
"github.com/QuantumNous/new-api/core/registry"
|
||
)
|
||
|
||
var (
|
||
// 全局Hook链实例(单例)
|
||
globalChain *HookChain
|
||
globalChainOnce sync.Once
|
||
)
|
||
|
||
// HookChain Hook执行链
|
||
type HookChain struct {
|
||
hooks []interfaces.RelayHook
|
||
mu sync.RWMutex
|
||
}
|
||
|
||
// GetGlobalChain 获取全局Hook链实例
|
||
func GetGlobalChain() *HookChain {
|
||
globalChainOnce.Do(func() {
|
||
globalChain = &HookChain{
|
||
hooks: make([]interfaces.RelayHook, 0),
|
||
}
|
||
// 从注册中心加载hooks
|
||
globalChain.LoadHooks()
|
||
})
|
||
return globalChain
|
||
}
|
||
|
||
// LoadHooks 从注册中心加载hooks
|
||
func (c *HookChain) LoadHooks() {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
c.hooks = registry.ListHooks()
|
||
common.SysLog(fmt.Sprintf("Loaded %d enabled hooks", len(c.hooks)))
|
||
}
|
||
|
||
// ReloadHooks 重新加载hooks
|
||
func (c *HookChain) ReloadHooks() {
|
||
c.LoadHooks()
|
||
common.SysLog("Hooks reloaded")
|
||
}
|
||
|
||
// ExecuteBeforeRequest 执行所有BeforeRequest钩子
|
||
func (c *HookChain) ExecuteBeforeRequest(ctx *interfaces.HookContext) error {
|
||
c.mu.RLock()
|
||
hooks := c.hooks
|
||
c.mu.RUnlock()
|
||
|
||
for _, hook := range hooks {
|
||
if !hook.Enabled() {
|
||
continue
|
||
}
|
||
|
||
if ctx.ShouldSkip {
|
||
break
|
||
}
|
||
|
||
if err := hook.OnBeforeRequest(ctx); err != nil {
|
||
common.SysError(fmt.Sprintf("Hook %s OnBeforeRequest error: %v", hook.Name(), err))
|
||
return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ExecuteAfterResponse 执行所有AfterResponse钩子
|
||
func (c *HookChain) ExecuteAfterResponse(ctx *interfaces.HookContext) error {
|
||
c.mu.RLock()
|
||
hooks := c.hooks
|
||
c.mu.RUnlock()
|
||
|
||
for _, hook := range hooks {
|
||
if !hook.Enabled() {
|
||
continue
|
||
}
|
||
|
||
if ctx.ShouldSkip {
|
||
break
|
||
}
|
||
|
||
if err := hook.OnAfterResponse(ctx); err != nil {
|
||
common.SysError(fmt.Sprintf("Hook %s OnAfterResponse error: %v", hook.Name(), err))
|
||
return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ExecuteOnError 执行所有OnError钩子
|
||
func (c *HookChain) ExecuteOnError(ctx *interfaces.HookContext, err error) error {
|
||
c.mu.RLock()
|
||
hooks := c.hooks
|
||
c.mu.RUnlock()
|
||
|
||
for _, hook := range hooks {
|
||
if !hook.Enabled() {
|
||
continue
|
||
}
|
||
|
||
if hookErr := hook.OnError(ctx, err); hookErr != nil {
|
||
common.SysError(fmt.Sprintf("Hook %s OnError failed: %v", hook.Name(), hookErr))
|
||
// OnError钩子的错误不会中断执行
|
||
}
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
// GetHooks 获取当前hook列表
|
||
func (c *HookChain) GetHooks() []interfaces.RelayHook {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
hooks := make([]interfaces.RelayHook, len(c.hooks))
|
||
copy(hooks, c.hooks)
|
||
return hooks
|
||
}
|
||
|
||
// Count 返回hook数量
|
||
func (c *HookChain) Count() int {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
return len(c.hooks)
|
||
}
|
||
|