Compare commits

...

6 Commits

Author SHA1 Message Date
Apple\Apple
39c841e600 feat(architecture): Core+Plugin 2025-10-13 02:02:11 +08:00
CaIon
ede47ef014 feat: support free model setting 2025-10-12 13:31:03 +08:00
Seefs
6c7795238f Merge pull request #2023 from seefs001/fix/version
fix: version
2025-10-12 13:05:51 +08:00
Seefs
0baacb2686 fix: version 2025-10-12 13:05:13 +08:00
Seefs
c5aaee9f2f Merge pull request #2022 from seefs001/fix/ignore_ghcr
ignore ghcr
2025-10-12 12:40:18 +08:00
Seefs
1987c7e16c ignore ghcr 2025-10-12 12:38:44 +08:00
30 changed files with 2732 additions and 62 deletions

View File

@@ -38,8 +38,8 @@ jobs:
echo "Building tag: $TAG for ${{ matrix.arch }}"
- name: Normalize GHCR repository
run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
# - name: Normalize GHCR repository
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -50,12 +50,12 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# - name: Log in to GHCR
# uses: docker/login-action@v3
# with:
# registry: ghcr.io
# username: ${{ github.actor }}
# password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (labels)
id: meta
@@ -63,7 +63,7 @@ jobs:
with:
images: |
calciumion/new-api
ghcr.io/${{ env.GHCR_REPOSITORY }}
# ghcr.io/${{ env.GHCR_REPOSITORY }}
- name: Build & push single-arch (to both registries)
uses: docker/build-push-action@v6
@@ -74,8 +74,8 @@ jobs:
tags: |
calciumion/new-api:${{ env.TAG }}-${{ matrix.arch }}
calciumion/new-api:latest-${{ matrix.arch }}
ghcr.io/${{ env.GHCR_REPOSITORY }}:${{ env.TAG }}-${{ matrix.arch }}
ghcr.io/${{ env.GHCR_REPOSITORY }}:latest-${{ matrix.arch }}
# ghcr.io/${{ env.GHCR_REPOSITORY }}:${{ env.TAG }}-${{ matrix.arch }}
# ghcr.io/${{ env.GHCR_REPOSITORY }}:latest-${{ matrix.arch }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
@@ -83,16 +83,16 @@ jobs:
sbom: false
create_manifests:
name: Create multi-arch manifests (Docker Hub + GHCR)
name: Create multi-arch manifests (Docker Hub)
needs: [build_single_arch]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Extract tag
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
- name: Normalize GHCR repository
run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
#
# - name: Normalize GHCR repository
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV
- name: Log in to Docker Hub
uses: docker/login-action@v3
@@ -115,23 +115,23 @@ jobs:
calciumion/new-api:latest-arm64
# ---- GHCR ----
- name: Log in to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# - name: Log in to GHCR
# uses: docker/login-action@v3
# with:
# registry: ghcr.io
# username: ${{ github.actor }}
# password: ${{ secrets.GITHUB_TOKEN }}
- name: Create & push manifest (GHCR - version)
run: |
docker buildx imagetools create \
-t ghcr.io/${GHCR_REPOSITORY}:${TAG} \
ghcr.io/${GHCR_REPOSITORY}:${TAG}-amd64 \
ghcr.io/${GHCR_REPOSITORY}:${TAG}-arm64
- name: Create & push manifest (GHCR - latest)
run: |
docker buildx imagetools create \
-t ghcr.io/${GHCR_REPOSITORY}:latest \
ghcr.io/${GHCR_REPOSITORY}:latest-amd64 \
ghcr.io/${GHCR_REPOSITORY}:latest-arm64
# - name: Create & push manifest (GHCR - version)
# run: |
# docker buildx imagetools create \
# -t ghcr.io/${GHCR_REPOSITORY}:${TAG} \
# ghcr.io/${GHCR_REPOSITORY}:${TAG}-amd64 \
# ghcr.io/${GHCR_REPOSITORY}:${TAG}-arm64
#
# - name: Create & push manifest (GHCR - latest)
# run: |
# docker buildx imagetools create \
# -t ghcr.io/${GHCR_REPOSITORY}:latest \
# ghcr.io/${GHCR_REPOSITORY}:latest-amd64 \
# ghcr.io/${GHCR_REPOSITORY}:latest-arm64

View File

@@ -23,7 +23,7 @@ RUN go mod download
COPY . .
COPY --from=builder /build/dist ./web/dist
RUN go build -ldflags "-s -w -X 'new-api/common.Version=$(cat VERSION)'" -o new-api
RUN go build -ldflags "-s -w -X 'github.com/QuantumNous/new-api/common.Version=$(cat VERSION)'" -o new-api
FROM alpine

161
config/plugin_config.go Normal file
View File

@@ -0,0 +1,161 @@
package config
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
"gopkg.in/yaml.v3"
)
// PluginConfig 插件配置结构
type PluginConfig struct {
Channels map[string]interfaces.ChannelConfig `yaml:"channels"`
Middlewares []interfaces.MiddlewareConfig `yaml:"middlewares"`
Hooks HooksConfig `yaml:"hooks"`
}
// HooksConfig Hook配置
type HooksConfig struct {
Relay []interfaces.HookConfig `yaml:"relay"`
}
var (
// 全局配置实例
globalPluginConfig *PluginConfig
)
// LoadPluginConfig 加载插件配置
func LoadPluginConfig(configPath string) (*PluginConfig, error) {
// 如果没有指定配置文件路径,使用默认路径
if configPath == "" {
configPath = "config/plugins.yaml"
}
// 检查文件是否存在
if _, err := os.Stat(configPath); os.IsNotExist(err) {
common.SysLog(fmt.Sprintf("Plugin config file not found: %s, using default configuration", configPath))
return getDefaultConfig(), nil
}
// 读取配置文件
data, err := ioutil.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read plugin config: %w", err)
}
// 解析YAML
var config PluginConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse plugin config: %w", err)
}
// 环境变量替换
expandEnvVars(&config)
common.SysLog(fmt.Sprintf("Loaded plugin config from: %s", configPath))
return &config, nil
}
// getDefaultConfig 返回默认配置
func getDefaultConfig() *PluginConfig {
return &PluginConfig{
Channels: make(map[string]interfaces.ChannelConfig),
Middlewares: make([]interfaces.MiddlewareConfig, 0),
Hooks: HooksConfig{
Relay: make([]interfaces.HookConfig, 0),
},
}
}
// expandEnvVars 展开环境变量
func expandEnvVars(config *PluginConfig) {
// 展开Hook配置中的环境变量
for i := range config.Hooks.Relay {
for key, value := range config.Hooks.Relay[i].Config {
if strValue, ok := value.(string); ok {
config.Hooks.Relay[i].Config[key] = os.ExpandEnv(strValue)
}
}
}
// 展开Middleware配置中的环境变量
for i := range config.Middlewares {
for key, value := range config.Middlewares[i].Config {
if strValue, ok := value.(string); ok {
config.Middlewares[i].Config[key] = os.ExpandEnv(strValue)
}
}
}
}
// GetGlobalPluginConfig 获取全局配置
func GetGlobalPluginConfig() *PluginConfig {
if globalPluginConfig == nil {
configPath := os.Getenv("PLUGIN_CONFIG_PATH")
if configPath == "" {
configPath = "config/plugins.yaml"
}
config, err := LoadPluginConfig(configPath)
if err != nil {
common.SysError(fmt.Sprintf("Failed to load plugin config: %v", err))
config = getDefaultConfig()
}
globalPluginConfig = config
}
return globalPluginConfig
}
// SavePluginConfig 保存插件配置
func SavePluginConfig(config *PluginConfig, configPath string) error {
if configPath == "" {
configPath = "config/plugins.yaml"
}
// 确保目录存在
dir := filepath.Dir(configPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// 序列化为YAML
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// 写入文件
if err := ioutil.WriteFile(configPath, data, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
common.SysLog(fmt.Sprintf("Saved plugin config to: %s", configPath))
return nil
}
// ReloadPluginConfig 重新加载配置
func ReloadPluginConfig() error {
configPath := os.Getenv("PLUGIN_CONFIG_PATH")
if configPath == "" {
configPath = "config/plugins.yaml"
}
config, err := LoadPluginConfig(configPath)
if err != nil {
return err
}
globalPluginConfig = config
common.SysLog("Plugin config reloaded")
return nil
}

52
config/plugins.yaml Normal file
View File

@@ -0,0 +1,52 @@
# New-API 插件配置
# 此文件用于配置所有插件的启用状态和参数
# Channel插件配置
channels:
openai:
enabled: true
priority: 100
claude:
enabled: true
priority: 90
gemini:
enabled: true
priority: 85
# Middleware插件配置
middlewares:
- name: auth
enabled: true
priority: 100
- name: ratelimit
enabled: true
priority: 90
config:
default_rate: 60
# Hook插件配置
hooks:
# Relay层Hook
relay:
# 联网搜索插件
- name: web_search
enabled: false # 默认禁用需要配置API key后启用
priority: 50
config:
provider: google
api_key: ${WEB_SEARCH_API_KEY} # 从环境变量读取
# 内容过滤插件
- name: content_filter
enabled: false # 默认禁用,需要配置后启用
priority: 100 # 高优先级,最后执行
config:
filter_nsfw: true
filter_political: false
sensitive_words:
- "敏感词1"
- "敏感词2"

View File

@@ -140,9 +140,13 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return
if priceData.FreeModel {
logger.LogInfo(c, fmt.Sprintf("模型 %s 免费,跳过预扣费", relayInfo.OriginModelName))
} else {
newAPIError = service.PreConsumeQuota(c, priceData.QuotaToPreConsume, relayInfo)
if newAPIError != nil {
return
}
}
defer func() {

View File

@@ -0,0 +1,66 @@
package interfaces
import (
"io"
"net/http"
"github.com/QuantumNous/new-api/dto"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// ChannelPlugin 定义Channel插件接口
// 继承原有的Adaptor接口增加插件元数据
type ChannelPlugin interface {
// 插件元数据
Name() string
Version() string
Priority() int
// 原有Adaptor接口方法
Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
}
// TaskChannelPlugin 定义Task类型的Channel插件接口
type TaskChannelPlugin interface {
// 插件元数据
Name() string
Version() string
Priority() int
// 原有TaskAdaptor接口方法
Init(info *relaycommon.RelayInfo)
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
GetModelList() []string
GetChannelName() string
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
// ChannelConfig 插件配置
type ChannelConfig struct {
Enabled bool `yaml:"enabled"`
Priority int `yaml:"priority"`
Config map[string]interface{} `yaml:"config"`
}

93
core/interfaces/hook.go Normal file
View File

@@ -0,0 +1,93 @@
package interfaces
import (
"io"
"net/http"
"github.com/gin-gonic/gin"
)
// HookContext Relay Hook执行上下文
type HookContext struct {
// Gin Context
GinContext *gin.Context
// Request相关
Request *http.Request
RequestBody []byte
// Response相关
Response *http.Response
ResponseBody []byte
// Channel信息
ChannelID int
ChannelType int
ChannelName string
// Model信息
Model string
OriginalModel string
// User信息
UserID int
TokenID int
Group string
// 扩展数据(插件间共享)
Data map[string]interface{}
// 错误信息
Error error
// 是否跳过后续处理
ShouldSkip bool
}
// RelayHook Relay Hook接口
type RelayHook interface {
// 插件元数据
Name() string
Priority() int
Enabled() bool
// 生命周期钩子
// OnBeforeRequest 在请求发送到上游之前执行
OnBeforeRequest(ctx *HookContext) error
// OnAfterResponse 在收到上游响应之后执行
OnAfterResponse(ctx *HookContext) error
// OnError 在发生错误时执行
OnError(ctx *HookContext, err error) error
}
// RequestModifier 请求修改器接口
// 实现此接口的Hook可以修改请求内容
type RequestModifier interface {
RelayHook
ModifyRequest(ctx *HookContext, body io.Reader) (io.Reader, error)
}
// ResponseProcessor 响应处理器接口
// 实现此接口的Hook可以处理响应内容
type ResponseProcessor interface {
RelayHook
ProcessResponse(ctx *HookContext, body []byte) ([]byte, error)
}
// StreamProcessor 流式响应处理器接口
// 实现此接口的Hook可以处理流式响应
type StreamProcessor interface {
RelayHook
ProcessStreamChunk(ctx *HookContext, chunk []byte) ([]byte, error)
}
// HookConfig Hook配置
type HookConfig struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
Priority int `yaml:"priority"`
Config map[string]interface{} `yaml:"config"`
}

View File

@@ -0,0 +1,31 @@
package interfaces
import (
"github.com/gin-gonic/gin"
)
// MiddlewarePlugin 中间件插件接口
type MiddlewarePlugin interface {
// 插件元数据
Name() string
Priority() int
Enabled() bool
// 返回Gin中间件处理函数
Handler() gin.HandlerFunc
// 初始化(可选)
Initialize(config MiddlewareConfig) error
}
// MiddlewareConfig 中间件配置
type MiddlewareConfig struct {
Name string `yaml:"name"`
Enabled bool `yaml:"enabled"`
Priority int `yaml:"priority"`
Config map[string]interface{} `yaml:"config"`
}
// MiddlewareFactory 中间件工厂函数类型
type MiddlewareFactory func(config MiddlewareConfig) (MiddlewarePlugin, error)

View File

@@ -0,0 +1,171 @@
package registry
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/core/interfaces"
)
var (
// 全局Channel注册表
channelRegistry = &ChannelRegistry{plugins: make(map[int]interfaces.ChannelPlugin)}
channelRegistryLock sync.RWMutex
// 全局TaskChannel注册表
taskChannelRegistry = &TaskChannelRegistry{plugins: make(map[string]interfaces.TaskChannelPlugin)}
taskChannelRegistryLock sync.RWMutex
)
// ChannelRegistry Channel插件注册中心
type ChannelRegistry struct {
plugins map[int]interfaces.ChannelPlugin // channelType -> plugin
mu sync.RWMutex
}
// Register 注册Channel插件
func (r *ChannelRegistry) Register(channelType int, plugin interfaces.ChannelPlugin) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.plugins[channelType]; exists {
return fmt.Errorf("channel plugin for type %d already registered", channelType)
}
r.plugins[channelType] = plugin
return nil
}
// Get 获取Channel插件
func (r *ChannelRegistry) Get(channelType int) (interfaces.ChannelPlugin, error) {
r.mu.RLock()
defer r.mu.RUnlock()
plugin, exists := r.plugins[channelType]
if !exists {
return nil, fmt.Errorf("channel plugin for type %d not found", channelType)
}
return plugin, nil
}
// List 列出所有已注册的Channel插件
func (r *ChannelRegistry) List() []interfaces.ChannelPlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.ChannelPlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
plugins = append(plugins, plugin)
}
return plugins
}
// Has 检查是否存在指定的Channel插件
func (r *ChannelRegistry) Has(channelType int) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, exists := r.plugins[channelType]
return exists
}
// TaskChannelRegistry TaskChannel插件注册中心
type TaskChannelRegistry struct {
plugins map[string]interfaces.TaskChannelPlugin // platform -> plugin
mu sync.RWMutex
}
// Register 注册TaskChannel插件
func (r *TaskChannelRegistry) Register(platform string, plugin interfaces.TaskChannelPlugin) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.plugins[platform]; exists {
return fmt.Errorf("task channel plugin for platform %s already registered", platform)
}
r.plugins[platform] = plugin
return nil
}
// Get 获取TaskChannel插件
func (r *TaskChannelRegistry) Get(platform string) (interfaces.TaskChannelPlugin, error) {
r.mu.RLock()
defer r.mu.RUnlock()
plugin, exists := r.plugins[platform]
if !exists {
return nil, fmt.Errorf("task channel plugin for platform %s not found", platform)
}
return plugin, nil
}
// List 列出所有已注册的TaskChannel插件
func (r *TaskChannelRegistry) List() []interfaces.TaskChannelPlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.TaskChannelPlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
plugins = append(plugins, plugin)
}
return plugins
}
// 全局函数 - Channel Registry
// RegisterChannel 注册Channel插件
func RegisterChannel(channelType int, plugin interfaces.ChannelPlugin) error {
channelRegistryLock.Lock()
defer channelRegistryLock.Unlock()
return channelRegistry.Register(channelType, plugin)
}
// GetChannel 获取Channel插件
func GetChannel(channelType int) (interfaces.ChannelPlugin, error) {
channelRegistryLock.RLock()
defer channelRegistryLock.RUnlock()
return channelRegistry.Get(channelType)
}
// ListChannels 列出所有Channel插件
func ListChannels() []interfaces.ChannelPlugin {
channelRegistryLock.RLock()
defer channelRegistryLock.RUnlock()
return channelRegistry.List()
}
// HasChannel 检查是否存在指定的Channel插件
func HasChannel(channelType int) bool {
channelRegistryLock.RLock()
defer channelRegistryLock.RUnlock()
return channelRegistry.Has(channelType)
}
// 全局函数 - TaskChannel Registry
// RegisterTaskChannel 注册TaskChannel插件
func RegisterTaskChannel(platform string, plugin interfaces.TaskChannelPlugin) error {
taskChannelRegistryLock.Lock()
defer taskChannelRegistryLock.Unlock()
return taskChannelRegistry.Register(platform, plugin)
}
// GetTaskChannel 获取TaskChannel插件
func GetTaskChannel(platform string) (interfaces.TaskChannelPlugin, error) {
taskChannelRegistryLock.RLock()
defer taskChannelRegistryLock.RUnlock()
return taskChannelRegistry.Get(platform)
}
// ListTaskChannels 列出所有TaskChannel插件
func ListTaskChannels() []interfaces.TaskChannelPlugin {
taskChannelRegistryLock.RLock()
defer taskChannelRegistryLock.RUnlock()
return taskChannelRegistry.List()
}

View File

@@ -0,0 +1,183 @@
package registry
import (
"fmt"
"sort"
"sync"
"github.com/QuantumNous/new-api/core/interfaces"
)
var (
// 全局Hook注册表
hookRegistry = &HookRegistry{hooks: make([]interfaces.RelayHook, 0)}
hookRegistryLock sync.RWMutex
)
// HookRegistry Hook插件注册中心
type HookRegistry struct {
hooks []interfaces.RelayHook
sorted bool
mu sync.RWMutex
}
// Register 注册Hook插件
func (r *HookRegistry) Register(hook interfaces.RelayHook) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查是否已存在同名Hook
for _, h := range r.hooks {
if h.Name() == hook.Name() {
return fmt.Errorf("hook %s already registered", hook.Name())
}
}
r.hooks = append(r.hooks, hook)
r.sorted = false // 标记需要重新排序
return nil
}
// Get 获取指定名称的Hook插件
func (r *HookRegistry) Get(name string) (interfaces.RelayHook, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, hook := range r.hooks {
if hook.Name() == name {
return hook, nil
}
}
return nil, fmt.Errorf("hook %s not found", name)
}
// List 列出所有已注册且启用的Hook插件按优先级排序
func (r *HookRegistry) List() []interfaces.RelayHook {
r.mu.Lock()
defer r.mu.Unlock()
// 如果未排序,先排序
if !r.sorted {
r.sortHooks()
}
// 只返回启用的hooks
enabledHooks := make([]interfaces.RelayHook, 0)
for _, hook := range r.hooks {
if hook.Enabled() {
enabledHooks = append(enabledHooks, hook)
}
}
return enabledHooks
}
// ListAll 列出所有已注册的Hook插件包括未启用的
func (r *HookRegistry) ListAll() []interfaces.RelayHook {
r.mu.RLock()
defer r.mu.RUnlock()
hooks := make([]interfaces.RelayHook, len(r.hooks))
copy(hooks, r.hooks)
return hooks
}
// sortHooks 按优先级排序hooks优先级数字越大越先执行
func (r *HookRegistry) sortHooks() {
sort.SliceStable(r.hooks, func(i, j int) bool {
return r.hooks[i].Priority() > r.hooks[j].Priority()
})
r.sorted = true
}
// Has 检查是否存在指定的Hook插件
func (r *HookRegistry) Has(name string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
for _, hook := range r.hooks {
if hook.Name() == name {
return true
}
}
return false
}
// Count 返回已注册的Hook数量
func (r *HookRegistry) Count() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.hooks)
}
// EnabledCount 返回已启用的Hook数量
func (r *HookRegistry) EnabledCount() int {
r.mu.RLock()
defer r.mu.RUnlock()
count := 0
for _, hook := range r.hooks {
if hook.Enabled() {
count++
}
}
return count
}
// 全局函数
// RegisterHook 注册Hook插件
func RegisterHook(hook interfaces.RelayHook) error {
hookRegistryLock.Lock()
defer hookRegistryLock.Unlock()
return hookRegistry.Register(hook)
}
// GetHook 获取Hook插件
func GetHook(name string) (interfaces.RelayHook, error) {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.Get(name)
}
// ListHooks 列出所有已启用的Hook插件
func ListHooks() []interfaces.RelayHook {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.List()
}
// ListAllHooks 列出所有Hook插件
func ListAllHooks() []interfaces.RelayHook {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.ListAll()
}
// HasHook 检查是否存在指定的Hook插件
func HasHook(name string) bool {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.Has(name)
}
// HookCount 返回已注册的Hook数量
func HookCount() int {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.Count()
}
// EnabledHookCount 返回已启用的Hook数量
func EnabledHookCount() int {
hookRegistryLock.RLock()
defer hookRegistryLock.RUnlock()
return hookRegistry.EnabledCount()
}

View File

@@ -0,0 +1,133 @@
package registry
import (
"fmt"
"sort"
"sync"
"github.com/QuantumNous/new-api/core/interfaces"
)
var (
// 全局Middleware注册表
middlewareRegistry = &MiddlewareRegistry{plugins: make(map[string]interfaces.MiddlewarePlugin)}
middlewareRegistryLock sync.RWMutex
)
// MiddlewareRegistry 中间件插件注册中心
type MiddlewareRegistry struct {
plugins map[string]interfaces.MiddlewarePlugin
mu sync.RWMutex
}
// Register 注册Middleware插件
func (r *MiddlewareRegistry) Register(plugin interfaces.MiddlewarePlugin) error {
r.mu.Lock()
defer r.mu.Unlock()
name := plugin.Name()
if _, exists := r.plugins[name]; exists {
return fmt.Errorf("middleware plugin %s already registered", name)
}
r.plugins[name] = plugin
return nil
}
// Get 获取Middleware插件
func (r *MiddlewareRegistry) Get(name string) (interfaces.MiddlewarePlugin, error) {
r.mu.RLock()
defer r.mu.RUnlock()
plugin, exists := r.plugins[name]
if !exists {
return nil, fmt.Errorf("middleware plugin %s not found", name)
}
return plugin, nil
}
// List 列出所有已注册的Middleware插件按优先级排序
func (r *MiddlewareRegistry) List() []interfaces.MiddlewarePlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.MiddlewarePlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
plugins = append(plugins, plugin)
}
// 按优先级排序(优先级数字越大越先执行)
sort.SliceStable(plugins, func(i, j int) bool {
return plugins[i].Priority() > plugins[j].Priority()
})
return plugins
}
// ListEnabled 列出所有已启用的Middleware插件按优先级排序
func (r *MiddlewareRegistry) ListEnabled() []interfaces.MiddlewarePlugin {
r.mu.RLock()
defer r.mu.RUnlock()
plugins := make([]interfaces.MiddlewarePlugin, 0, len(r.plugins))
for _, plugin := range r.plugins {
if plugin.Enabled() {
plugins = append(plugins, plugin)
}
}
// 按优先级排序
sort.SliceStable(plugins, func(i, j int) bool {
return plugins[i].Priority() > plugins[j].Priority()
})
return plugins
}
// Has 检查是否存在指定的Middleware插件
func (r *MiddlewareRegistry) Has(name string) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, exists := r.plugins[name]
return exists
}
// 全局函数
// RegisterMiddleware 注册Middleware插件
func RegisterMiddleware(plugin interfaces.MiddlewarePlugin) error {
middlewareRegistryLock.Lock()
defer middlewareRegistryLock.Unlock()
return middlewareRegistry.Register(plugin)
}
// GetMiddleware 获取Middleware插件
func GetMiddleware(name string) (interfaces.MiddlewarePlugin, error) {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.Get(name)
}
// ListMiddlewares 列出所有Middleware插件
func ListMiddlewares() []interfaces.MiddlewarePlugin {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.List()
}
// ListEnabledMiddlewares 列出所有已启用的Middleware插件
func ListEnabledMiddlewares() []interfaces.MiddlewarePlugin {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.ListEnabled()
}
// HasMiddleware 检查是否存在指定的Middleware插件
func HasMiddleware(name string) bool {
middlewareRegistryLock.RLock()
defer middlewareRegistryLock.RUnlock()
return middlewareRegistry.Has(name)
}

View File

@@ -0,0 +1,116 @@
package registry
import (
"testing"
"github.com/QuantumNous/new-api/core/interfaces"
)
// Mock Hook实现
type mockHook struct {
name string
priority int
enabled bool
}
func (m *mockHook) Name() string { return m.name }
func (m *mockHook) Priority() int { return m.priority }
func (m *mockHook) Enabled() bool { return m.enabled }
func (m *mockHook) OnBeforeRequest(ctx *interfaces.HookContext) error { return nil }
func (m *mockHook) OnAfterResponse(ctx *interfaces.HookContext) error { return nil }
func (m *mockHook) OnError(ctx *interfaces.HookContext, err error) error { return nil }
func TestHookRegistry(t *testing.T) {
// 创建新的注册表(用于测试)
registry := &HookRegistry{hooks: make([]interfaces.RelayHook, 0)}
// 测试注册Hook
hook1 := &mockHook{name: "test_hook_1", priority: 100, enabled: true}
hook2 := &mockHook{name: "test_hook_2", priority: 50, enabled: true}
hook3 := &mockHook{name: "test_hook_3", priority: 75, enabled: false}
if err := registry.Register(hook1); err != nil {
t.Errorf("Failed to register hook1: %v", err)
}
if err := registry.Register(hook2); err != nil {
t.Errorf("Failed to register hook2: %v", err)
}
if err := registry.Register(hook3); err != nil {
t.Errorf("Failed to register hook3: %v", err)
}
// 测试重复注册
if err := registry.Register(hook1); err == nil {
t.Error("Expected error when registering duplicate hook")
}
// 测试获取Hook
if hook, err := registry.Get("test_hook_1"); err != nil {
t.Errorf("Failed to get hook: %v", err)
} else if hook.Name() != "test_hook_1" {
t.Errorf("Got wrong hook: %s", hook.Name())
}
// 测试不存在的Hook
if _, err := registry.Get("nonexistent"); err == nil {
t.Error("Expected error when getting nonexistent hook")
}
// 测试List应该只返回enabled的hooks
hooks := registry.List()
if len(hooks) != 2 {
t.Errorf("Expected 2 enabled hooks, got %d", len(hooks))
}
// 测试优先级排序100应该在50之前
if hooks[0].Priority() != 100 {
t.Errorf("Expected first hook to have priority 100, got %d", hooks[0].Priority())
}
// 测试Count
if count := registry.Count(); count != 3 {
t.Errorf("Expected count 3, got %d", count)
}
// 测试EnabledCount
if count := registry.EnabledCount(); count != 2 {
t.Errorf("Expected enabled count 2, got %d", count)
}
// 测试Has
if !registry.Has("test_hook_1") {
t.Error("Expected to find test_hook_1")
}
if registry.Has("nonexistent") {
t.Error("Should not find nonexistent hook")
}
}
func TestChannelRegistry(t *testing.T) {
// 这里可以添加Channel Registry的测试
// 但需要mock ChannelPlugin接口比较复杂
// 作为示例,我们只测试基本功能
registry := &ChannelRegistry{plugins: make(map[int]interfaces.ChannelPlugin)}
// 测试Has方法
if registry.Has(1) {
t.Error("Should not find channel type 1")
}
}
func TestMiddlewareRegistry(t *testing.T) {
// Middleware Registry测试
// 需要mock MiddlewarePlugin接口
registry := &MiddlewareRegistry{plugins: make(map[string]interfaces.MiddlewarePlugin)}
// 测试Has方法
if registry.Has("test_middleware") {
t.Error("Should not find test_middleware")
}
}

View File

@@ -0,0 +1,359 @@
# New-API 插件化架构说明
## 完整目录结构
```
new-api-2/
├── core/ # 核心层(高性能,不可插件化)
│ ├── interfaces/ # 插件接口定义
│ │ ├── channel.go # Channel插件接口
│ │ ├── hook.go # Hook插件接口
│ │ └── middleware.go # Middleware插件接口
│ └── registry/ # 插件注册中心
│ ├── channel_registry.go # Channel注册器线程安全
│ ├── hook_registry.go # Hook注册器优先级排序
│ └── middleware_registry.go # Middleware注册器
├── plugins/ # 🔵 Tier 1: 编译时插件(已实施)
│ ├── channels/ # Channel插件
│ │ ├── base_plugin.go # 基础插件包装器
│ │ └── registry.go # 自动注册31个AI Provider
│ └── hooks/ # Hook插件
│ ├── web_search/ # 联网搜索Hook
│ │ ├── web_search_hook.go
│ │ └── init.go
│ └── content_filter/ # 内容过滤Hook
│ ├── content_filter_hook.go
│ └── init.go
├── marketplace/ # 🟣 Tier 2: 运行时插件待实施Phase 2
│ ├── loader/ # go-plugin加载器
│ │ ├── plugin_client.go # 插件客户端
│ │ ├── plugin_server.go # 插件服务器
│ │ └── lifecycle.go # 生命周期管理
│ ├── manager/ # 插件管理器
│ │ ├── installer.go # 安装/卸载
│ │ ├── updater.go # 版本更新
│ │ └── registry.go # 插件注册表
│ ├── security/ # 安全模块
│ │ ├── signature.go # Ed25519签名验证
│ │ ├── checksum.go # SHA256校验
│ │ └── sandbox.go # 沙箱配置
│ ├── store/ # 插件商店客户端
│ │ ├── client.go # 商店API客户端
│ │ ├── search.go # 搜索功能
│ │ └── download.go # 下载管理
│ └── proto/ # gRPC协议定义
│ ├── hook.proto # Hook插件协议
│ ├── channel.proto # Channel插件协议
│ └── common.proto # 通用消息
├── plugins_external/ # 第三方插件安装目录
│ ├── installed/ # 已安装插件
│ │ ├── awesome-hook-v1.0.0/
│ │ ├── custom-llm-v2.1.0/
│ │ └── slack-notify-v1.5.0/
│ ├── cache/ # 下载缓存
│ └── temp/ # 临时文件
├── relay/ # Relay层
│ ├── hooks/ # Hook执行链
│ │ ├── chain.go # Hook链管理器
│ │ ├── context.go # Hook上下文
│ │ └── context_builder.go # 上下文构建器
│ └── relay_adaptor.go # Channel适配器优先从Registry获取
├── config/ # 配置系统
│ ├── plugins.yaml # 插件配置Tier 1 + Tier 2
│ └── plugin_config.go # 配置加载器(支持环境变量)
└── (其他现有目录保持不变)
```
---
## 完整架构图
### 系统架构总览
```mermaid
graph TB
subgraph "🌐 API层"
Client[客户端请求]
end
subgraph "🔐 中间件层"
Auth[认证中间件]
RateLimit[限流中间件]
Cache[缓存中间件]
end
subgraph "🎯 核心层 Core"
Registry[插件注册中心]
ChannelReg[Channel Registry]
HookReg[Hook Registry]
MidReg[Middleware Registry]
Registry --> ChannelReg
Registry --> HookReg
Registry --> MidReg
end
subgraph "🔵 Tier 1: 编译时插件(已实施)"
direction TB
Channels[31个 Channel Plugins]
OpenAI[OpenAI]
Claude[Claude]
Gemini[Gemini]
Others[其他28个...]
Channels --> OpenAI
Channels --> Claude
Channels --> Gemini
Channels --> Others
Hooks[Hook Plugins]
WebSearch[Web Search Hook]
ContentFilter[Content Filter Hook]
Hooks --> WebSearch
Hooks --> ContentFilter
end
subgraph "🟣 Tier 2: 运行时插件(待实施)"
direction TB
Marketplace[🏪 Plugin Marketplace]
ExtHook[External Hooks<br/>Python/Go/Node.js]
ExtChannel[External Channels<br/>小众AI提供商]
ExtMid[External Middleware<br/>企业集成]
ExtUI[UI Extensions<br/>自定义仪表板]
Marketplace --> ExtHook
Marketplace --> ExtChannel
Marketplace --> ExtMid
Marketplace --> ExtUI
end
subgraph "⚡ Relay执行流程"
direction LR
HookChain[Hook Chain]
BeforeHook[OnBeforeRequest]
ChannelAdaptor[Channel Adaptor]
AfterHook[OnAfterResponse]
HookChain --> BeforeHook
BeforeHook --> ChannelAdaptor
ChannelAdaptor --> AfterHook
end
subgraph "🌍 上游服务"
Upstream[AI Provider APIs]
end
Client --> Auth
Auth --> RateLimit
RateLimit --> Cache
Cache --> Registry
Channels --> ChannelReg
Hooks --> HookReg
Registry --> HookChain
HookChain --> Upstream
Upstream --> HookChain
Registry -.gRPC/RPC.-> ExtHook
Registry -.gRPC/RPC.-> ExtChannel
Registry -.gRPC/RPC.-> ExtMid
style Marketplace fill:#f9f,stroke:#333,stroke-width:4px
style Registry fill:#bbf,stroke:#333,stroke-width:4px
style Channels fill:#bfb,stroke:#333,stroke-width:2px
style Hooks fill:#bfb,stroke:#333,stroke-width:2px
```
### 双层插件系统架构
```mermaid
graph LR
subgraph "🔵 Tier 1: 编译时插件"
T1[性能: 100%<br/>语言: Go only<br/>部署: 编译到二进制]
T1Chan[31 Channels]
T1Hook[2 Hooks]
T1 --> T1Chan
T1 --> T1Hook
end
subgraph "🟣 Tier 2: 运行时插件"
T2[性能: 90-95%<br/>语言: Go/Python/Node.js<br/>部署: 独立进程]
T2Hook[External Hooks]
T2Chan[External Channels]
T2Mid[External Middleware]
T2UI[UI Extensions]
T2 --> T2Hook
T2 --> T2Chan
T2 --> T2Mid
T2 --> T2UI
end
T1 -.进程内调用.-> Core[Core System]
T2 -.gRPC/RPC.-> Core
style T1 fill:#bfb,stroke:#333,stroke-width:3px
style T2 fill:#f9f,stroke:#333,stroke-width:3px
style Core fill:#bbf,stroke:#333,stroke-width:3px
```
---
## 核心要点说明
### 1. 双层插件架构
| 层级 | 技术方案 | 性能 | 适用场景 | 开发语言 |
|------|---------|------|---------|---------|
| **Tier 1<br/>编译时插件** | 编译时链接 | 100%<br/>零损失 | • 核心ChannelOpenAI等<br/>• 内置Hook<br/>• 高频调用路径 | Go only |
| **Tier 2<br/>运行时插件** | go-plugin<br/>gRPC | 90-95%<br/>5-10%开销 | • 第三方扩展<br/>• 企业定制<br/>• 多语言集成 | Go/Python/<br/>Node.js/Rust |
### 2. 核心组件
#### Core层核心引擎
- **interfaces/**: 定义ChannelPlugin、RelayHook、MiddlewarePlugin接口
- **registry/**: 线程安全的插件注册中心支持O(1)查找、优先级排序
#### Relay Hook链
- **执行流程**: OnBeforeRequest → Channel.DoRequest → OnAfterResponse
- **特性**: 优先级排序、短路机制、数据共享HookContext.Data
- **应用场景**: 联网搜索、内容过滤、日志增强、缓存策略
### 3. Tier 1: 编译时插件(已实施 ✅)
**特点**:
- 零性能损失,编译后与硬编码无差异
- init()函数自动注册到Registry
- YAML配置启用/禁用
**已实现**:
- ✅ 31个Channel插件OpenAI、Claude、Gemini等
- ✅ 2个Hook插件web_search、content_filter
- ✅ Hook执行链
- ✅ 配置系统(支持环境变量展开)
### 4. Tier 2: 运行时插件(待实施 🚧)
**基于**: [hashicorp/go-plugin](https://github.com/hashicorp/go-plugin)Vault/Terraform使用
**优势**:
- ✅ 进程隔离(第三方代码崩溃不影响主程序)
- ✅ 多语言支持gRPC协议
- ✅ 热插拔(无需重启)
- ✅ 安全验证Ed25519签名 + SHA256校验 + TLS加密
- ✅ 独立分发(插件商店)
**适用场景**:
- 第三方开发者扩展
- 企业定制业务逻辑
- Python ML模型集成
- 第三方服务集成Slack/钉钉/企业微信)
- UI扩展
### 5. 安全机制
**Tier 1编译时**:
- 内部代码审查
- 编译期类型安全
**Tier 2运行时**:
- Ed25519签名验证
- SHA256校验和
- gRPC TLS加密
- 进程资源限制(内存/CPU
- 插件商店审核机制
- 可信发布者白名单
### 6. 配置系统
**单一配置文件**: `config/plugins.yaml`
```yaml
# Tier 1: 编译时插件
plugins:
hooks:
- name: web_search
enabled: false
priority: 50
config:
api_key: ${WEB_SEARCH_API_KEY}
# Tier 2: 运行时插件(待实施)
external_plugins:
enabled: true
hooks:
- name: awesome_hook
binary: awesome-hook-v1.0.0/awesome-hook
checksum: sha256:abc123...
# 插件商店
marketplace:
enabled: true
api_url: https://plugins.new-api.com
```
### 7. 性能对比
| 场景 | Tier 1 | Tier 2 | RPC开销 |
|------|--------|--------|--------|
| 核心Channel | 100% | N/A | 0% |
| 内置Hook | 100% | N/A | 0% |
| 第三方Hook | N/A | 92-95% | 5-8% |
| Python插件 | N/A | 88-92% | 8-12% |
### 8. 实施路线图
#### Phase 1: 编译时插件系统 ✅ 已完成
- Core Registry + Hook Chain
- 31个Channel插件 + 2个Hook示例
- YAML配置系统
#### Phase 2: go-plugin基础
- protobuf协议定义
- PluginLoader实现
- 签名验证系统
- Python/Go SDK
#### Phase 3: 插件商店
- 商店后端API
- Web UI搜索、安装、管理
- CLI工具
- 多语言SDK
### 9. 扩展示例
**新增Tier 1插件编译时**:
```go
// 1. 实现接口
type MyHook struct{}
func (h *MyHook) OnBeforeRequest(ctx *HookContext) error { /*...*/ }
// 2. 注册
func init() { registry.RegisterHook(&MyHook{}) }
// 3. 导入到main.go
import _ "github.com/xxx/plugins/hooks/my_hook"
```
**新增Tier 2插件运行时**:
```python
# external-plugin/my_hook.py
from new_api_plugin_sdk import HookPlugin, serve
class MyHook(HookPlugin):
def on_before_request(self, ctx):
return {"modified_body": ctx.request_body}
serve(MyHook())
```

1
go.mod
View File

@@ -40,6 +40,7 @@ require (
golang.org/x/image v0.23.0
golang.org/x/net v0.43.0
golang.org/x/sync v0.17.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2

36
main.go
View File

@@ -21,6 +21,13 @@ import (
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/ratio_setting"
// Plugin System
coreregistry "github.com/QuantumNous/new-api/core/registry"
_ "github.com/QuantumNous/new-api/plugins/channels" // 自动注册channel插件
_ "github.com/QuantumNous/new-api/plugins/hooks/web_search" // 自动注册web_search hook
_ "github.com/QuantumNous/new-api/plugins/hooks/content_filter" // 自动注册content_filter hook
relayhooks "github.com/QuantumNous/new-api/relay/hooks"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
@@ -229,5 +236,34 @@ func InitResources() error {
if err != nil {
return err
}
// Initialize Plugin System
InitPluginSystem()
return nil
}
// InitPluginSystem 初始化插件系统
func InitPluginSystem() {
common.SysLog("Initializing plugin system...")
// 1. 加载插件配置
// config.LoadPluginConfig() 会在各个插件的init()中自动调用
// 2. 注册Channel插件
// 注意:这会在 plugins/channels/registry.go 的 init() 中自动完成
// 但为了确保加载,我们显式导入
common.SysLog("Registering channel plugins...")
// 3. 初始化Hook链
common.SysLog("Initializing hook chain...")
_ = relayhooks.GetGlobalChain()
hookCount := coreregistry.HookCount()
enabledHookCount := coreregistry.EnabledHookCount()
common.SysLog(fmt.Sprintf("Plugin system initialized: %d hooks registered (%d enabled)",
hookCount, enabledHookCount))
channelCount := len(coreregistry.ListChannels())
common.SysLog(fmt.Sprintf("Registered %d channel plugins", channelCount))
}

View File

@@ -0,0 +1,114 @@
package channels
import (
"io"
"net/http"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// BaseChannelPlugin 基础Channel插件
// 包装现有的Adaptor实现使其符合ChannelPlugin接口
type BaseChannelPlugin struct {
adaptor channel.Adaptor
name string
version string
priority int
}
// NewBaseChannelPlugin 创建基础Channel插件
func NewBaseChannelPlugin(adaptor channel.Adaptor, name, version string, priority int) *BaseChannelPlugin {
return &BaseChannelPlugin{
adaptor: adaptor,
name: name,
version: version,
priority: priority,
}
}
// Name 返回插件名称
func (p *BaseChannelPlugin) Name() string {
return p.name
}
// Version 返回插件版本
func (p *BaseChannelPlugin) Version() string {
return p.version
}
// Priority 返回优先级
func (p *BaseChannelPlugin) Priority() int {
return p.priority
}
// 以下方法直接委托给内部的Adaptor
func (p *BaseChannelPlugin) Init(info *relaycommon.RelayInfo) {
p.adaptor.Init(info)
}
func (p *BaseChannelPlugin) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return p.adaptor.GetRequestURL(info)
}
func (p *BaseChannelPlugin) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
return p.adaptor.SetupRequestHeader(c, req, info)
}
func (p *BaseChannelPlugin) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return p.adaptor.ConvertOpenAIRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return p.adaptor.ConvertRerankRequest(c, relayMode, request)
}
func (p *BaseChannelPlugin) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return p.adaptor.ConvertEmbeddingRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return p.adaptor.ConvertAudioRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return p.adaptor.ConvertImageRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return p.adaptor.ConvertOpenAIResponsesRequest(c, info, request)
}
func (p *BaseChannelPlugin) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return p.adaptor.DoRequest(c, info, requestBody)
}
func (p *BaseChannelPlugin) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
return p.adaptor.DoResponse(c, resp, info)
}
func (p *BaseChannelPlugin) GetModelList() []string {
return p.adaptor.GetModelList()
}
func (p *BaseChannelPlugin) GetChannelName() string {
return p.adaptor.GetChannelName()
}
func (p *BaseChannelPlugin) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return p.adaptor.ConvertClaudeRequest(c, info, request)
}
func (p *BaseChannelPlugin) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
return p.adaptor.ConvertGeminiRequest(c, info, request)
}
// GetAdaptor 获取内部的Adaptor用于向后兼容
func (p *BaseChannelPlugin) GetAdaptor() channel.Adaptor {
return p.adaptor
}

View File

@@ -0,0 +1,106 @@
package channels
import (
"fmt"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/core/registry"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/ali"
"github.com/QuantumNous/new-api/relay/channel/aws"
"github.com/QuantumNous/new-api/relay/channel/baidu"
"github.com/QuantumNous/new-api/relay/channel/baidu_v2"
"github.com/QuantumNous/new-api/relay/channel/claude"
"github.com/QuantumNous/new-api/relay/channel/cloudflare"
"github.com/QuantumNous/new-api/relay/channel/cohere"
"github.com/QuantumNous/new-api/relay/channel/coze"
"github.com/QuantumNous/new-api/relay/channel/deepseek"
"github.com/QuantumNous/new-api/relay/channel/dify"
"github.com/QuantumNous/new-api/relay/channel/gemini"
"github.com/QuantumNous/new-api/relay/channel/jimeng"
"github.com/QuantumNous/new-api/relay/channel/jina"
"github.com/QuantumNous/new-api/relay/channel/mistral"
"github.com/QuantumNous/new-api/relay/channel/mokaai"
"github.com/QuantumNous/new-api/relay/channel/moonshot"
"github.com/QuantumNous/new-api/relay/channel/ollama"
"github.com/QuantumNous/new-api/relay/channel/openai"
"github.com/QuantumNous/new-api/relay/channel/palm"
"github.com/QuantumNous/new-api/relay/channel/perplexity"
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
"github.com/QuantumNous/new-api/relay/channel/submodel"
"github.com/QuantumNous/new-api/relay/channel/tencent"
"github.com/QuantumNous/new-api/relay/channel/vertex"
"github.com/QuantumNous/new-api/relay/channel/volcengine"
"github.com/QuantumNous/new-api/relay/channel/xai"
"github.com/QuantumNous/new-api/relay/channel/xunfei"
"github.com/QuantumNous/new-api/relay/channel/zhipu"
"github.com/QuantumNous/new-api/relay/channel/zhipu_4v"
)
// init 包初始化时自动注册所有Channel插件
func init() {
RegisterAllChannels()
}
// RegisterAllChannels 注册所有Channel插件
func RegisterAllChannels() {
// 包装现有的Adaptor并注册为插件
channels := []struct {
channelType int
adaptor channel.Adaptor
name string
}{
{constant.APITypeOpenAI, &openai.Adaptor{}, "openai"},
{constant.APITypeAnthropic, &claude.Adaptor{}, "claude"},
{constant.APITypeGemini, &gemini.Adaptor{}, "gemini"},
{constant.APITypeAli, &ali.Adaptor{}, "ali"},
{constant.APITypeBaidu, &baidu.Adaptor{}, "baidu"},
{constant.APITypeBaiduV2, &baidu_v2.Adaptor{}, "baidu_v2"},
{constant.APITypeTencent, &tencent.Adaptor{}, "tencent"},
{constant.APITypeXunfei, &xunfei.Adaptor{}, "xunfei"},
{constant.APITypeZhipu, &zhipu.Adaptor{}, "zhipu"},
{constant.APITypeZhipuV4, &zhipu_4v.Adaptor{}, "zhipu_v4"},
{constant.APITypeOllama, &ollama.Adaptor{}, "ollama"},
{constant.APITypePerplexity, &perplexity.Adaptor{}, "perplexity"},
{constant.APITypeAws, &aws.Adaptor{}, "aws"},
{constant.APITypeCohere, &cohere.Adaptor{}, "cohere"},
{constant.APITypeDify, &dify.Adaptor{}, "dify"},
{constant.APITypeJina, &jina.Adaptor{}, "jina"},
{constant.APITypeCloudflare, &cloudflare.Adaptor{}, "cloudflare"},
{constant.APITypeSiliconFlow, &siliconflow.Adaptor{}, "siliconflow"},
{constant.APITypeVertexAi, &vertex.Adaptor{}, "vertex"},
{constant.APITypeMistral, &mistral.Adaptor{}, "mistral"},
{constant.APITypeDeepSeek, &deepseek.Adaptor{}, "deepseek"},
{constant.APITypeMokaAI, &mokaai.Adaptor{}, "mokaai"},
{constant.APITypeVolcEngine, &volcengine.Adaptor{}, "volcengine"},
{constant.APITypeXai, &xai.Adaptor{}, "xai"},
{constant.APITypeCoze, &coze.Adaptor{}, "coze"},
{constant.APITypeJimeng, &jimeng.Adaptor{}, "jimeng"},
{constant.APITypeMoonshot, &moonshot.Adaptor{}, "moonshot"},
{constant.APITypeSubmodel, &submodel.Adaptor{}, "submodel"},
{constant.APITypePaLM, &palm.Adaptor{}, "palm"},
// OpenRouter 和 Xinference 使用 OpenAI adaptor
{constant.APITypeOpenRouter, &openai.Adaptor{}, "openrouter"},
{constant.APITypeXinference, &openai.Adaptor{}, "xinference"},
}
registeredCount := 0
for _, ch := range channels {
plugin := NewBaseChannelPlugin(
ch.adaptor,
ch.name,
"1.0.0",
100, // 默认优先级
)
if err := registry.RegisterChannel(ch.channelType, plugin); err != nil {
common.SysError("Failed to register channel plugin: " + ch.name + ", error: " + err.Error())
} else {
registeredCount++
}
}
common.SysLog(fmt.Sprintf("Registered %d channel plugins", registeredCount))
}

View File

@@ -0,0 +1,186 @@
package content_filter
import (
"encoding/json"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
)
// ContentFilterHook 内容过滤Hook
// 在响应返回前过滤敏感内容
type ContentFilterHook struct {
enabled bool
priority int
sensitiveWords []string
filterNSFW bool
filterPolitical bool
replacementText string
}
// NewContentFilterHook 创建ContentFilterHook实例
func NewContentFilterHook(config map[string]interface{}) *ContentFilterHook {
hook := &ContentFilterHook{
enabled: true,
priority: 100, // 高优先级,最后执行
sensitiveWords: []string{},
filterNSFW: true,
filterPolitical: false,
replacementText: "[已过滤]",
}
if enabled, ok := config["enabled"].(bool); ok {
hook.enabled = enabled
}
if priority, ok := config["priority"].(int); ok {
hook.priority = priority
}
if filterNSFW, ok := config["filter_nsfw"].(bool); ok {
hook.filterNSFW = filterNSFW
}
if filterPolitical, ok := config["filter_political"].(bool); ok {
hook.filterPolitical = filterPolitical
}
if words, ok := config["sensitive_words"].([]interface{}); ok {
for _, word := range words {
if w, ok := word.(string); ok {
hook.sensitiveWords = append(hook.sensitiveWords, w)
}
}
}
return hook
}
// Name 返回Hook名称
func (h *ContentFilterHook) Name() string {
return "content_filter"
}
// Priority 返回优先级
func (h *ContentFilterHook) Priority() int {
return h.priority
}
// Enabled 返回是否启用
func (h *ContentFilterHook) Enabled() bool {
return h.enabled
}
// OnBeforeRequest 请求前处理(不需要处理)
func (h *ContentFilterHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
return nil
}
// OnAfterResponse 响应后处理 - 过滤内容
func (h *ContentFilterHook) OnAfterResponse(ctx *interfaces.HookContext) error {
if !h.Enabled() {
return nil
}
// 只处理chat completion响应
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
return nil
}
// 如果没有响应体,跳过
if len(ctx.ResponseBody) == 0 {
return nil
}
// 解析响应
var response map[string]interface{}
if err := json.Unmarshal(ctx.ResponseBody, &response); err != nil {
return nil // 忽略解析错误
}
// 过滤内容
filtered := h.filterResponse(response)
// 如果内容被修改,更新响应体
if filtered {
modifiedBody, err := json.Marshal(response)
if err != nil {
return err
}
ctx.ResponseBody = modifiedBody
// 记录过滤事件
ctx.Data["content_filtered"] = true
common.SysLog("Content filter applied to response")
}
return nil
}
// OnError 错误处理
func (h *ContentFilterHook) OnError(ctx *interfaces.HookContext, err error) error {
return nil
}
// filterResponse 过滤响应内容
func (h *ContentFilterHook) filterResponse(response map[string]interface{}) bool {
modified := false
// 获取choices数组
choices, ok := response["choices"].([]interface{})
if !ok {
return false
}
// 遍历每个choice
for _, choice := range choices {
choiceMap, ok := choice.(map[string]interface{})
if !ok {
continue
}
// 获取message
message, ok := choiceMap["message"].(map[string]interface{})
if !ok {
continue
}
// 获取content
content, ok := message["content"].(string)
if !ok {
continue
}
// 过滤内容
filteredContent := h.filterText(content)
// 如果内容被修改
if filteredContent != content {
message["content"] = filteredContent
modified = true
}
}
return modified
}
// filterText 过滤文本内容
func (h *ContentFilterHook) filterText(text string) string {
filtered := text
// 过滤敏感词
for _, word := range h.sensitiveWords {
if strings.Contains(filtered, word) {
filtered = strings.ReplaceAll(filtered, word, h.replacementText)
}
}
// TODO: 实现更复杂的过滤逻辑
// - NSFW内容检测
// - 政治敏感内容检测
// - 使用AI模型进行内容分类
return filtered
}

View File

@@ -0,0 +1,39 @@
package content_filter
import (
"os"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/registry"
)
func init() {
// 从环境变量读取配置
config := map[string]interface{}{
"enabled": os.Getenv("CONTENT_FILTER_ENABLED") == "true",
"priority": 100,
"filter_nsfw": os.Getenv("CONTENT_FILTER_NSFW") != "false",
"filter_political": os.Getenv("CONTENT_FILTER_POLITICAL") == "true",
}
// 读取敏感词列表
if wordsEnv := os.Getenv("CONTENT_FILTER_WORDS"); wordsEnv != "" {
words := strings.Split(wordsEnv, ",")
config["sensitive_words"] = words
}
// 创建并注册Hook
hook := NewContentFilterHook(config)
if err := registry.RegisterHook(hook); err != nil {
common.SysError("Failed to register content_filter hook: " + err.Error())
} else {
if hook.Enabled() {
common.SysLog("Content filter hook registered and enabled")
} else {
common.SysLog("Content filter hook registered but disabled")
}
}
}

View File

@@ -0,0 +1,39 @@
package web_search
import (
"os"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/registry"
)
func init() {
// 从环境变量读取配置
config := map[string]interface{}{
"enabled": os.Getenv("WEB_SEARCH_ENABLED") == "true",
"api_key": os.Getenv("WEB_SEARCH_API_KEY"),
"provider": getEnvOrDefault("WEB_SEARCH_PROVIDER", "google"),
"priority": 50,
}
// 创建并注册Hook
hook := NewWebSearchHook(config)
if err := registry.RegisterHook(hook); err != nil {
common.SysError("Failed to register web_search hook: " + err.Error())
} else {
if hook.Enabled() {
common.SysLog("Web search hook registered and enabled")
} else {
common.SysLog("Web search hook registered but disabled (missing API key or not enabled)")
}
}
}
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}

View File

@@ -0,0 +1,281 @@
package web_search
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
)
// WebSearchHook 联网搜索Hook插件
// 在请求发送前检测是否需要联网搜索如果需要则调用搜索API并将结果注入到请求中
type WebSearchHook struct {
enabled bool
priority int
apiKey string
provider string // google, bing, etc
}
// NewWebSearchHook 创建WebSearchHook实例
func NewWebSearchHook(config map[string]interface{}) *WebSearchHook {
hook := &WebSearchHook{
enabled: true,
priority: 50, // 中等优先级
provider: "google",
}
if apiKey, ok := config["api_key"].(string); ok {
hook.apiKey = apiKey
}
if provider, ok := config["provider"].(string); ok {
hook.provider = provider
}
if priority, ok := config["priority"].(int); ok {
hook.priority = priority
}
if enabled, ok := config["enabled"].(bool); ok {
hook.enabled = enabled
}
return hook
}
// Name 返回Hook名称
func (h *WebSearchHook) Name() string {
return "web_search"
}
// Priority 返回优先级
func (h *WebSearchHook) Priority() int {
return h.priority
}
// Enabled 返回是否启用
func (h *WebSearchHook) Enabled() bool {
return h.enabled && h.apiKey != ""
}
// OnBeforeRequest 请求前处理
func (h *WebSearchHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
if !h.Enabled() {
return nil
}
// 只处理chat completion请求
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
return nil
}
// 检查请求体中是否包含搜索关键词
if len(ctx.RequestBody) == 0 {
return nil
}
// 解析请求体
var requestData map[string]interface{}
if err := json.Unmarshal(ctx.RequestBody, &requestData); err != nil {
return nil // 忽略解析错误
}
// 检查是否需要搜索(简单示例:检查最后一条消息是否包含 [search] 标记)
if !h.shouldSearch(requestData) {
return nil
}
// 执行搜索
searchQuery := h.extractSearchQuery(requestData)
if searchQuery == "" {
return nil
}
common.SysLog(fmt.Sprintf("Web search triggered for query: %s", searchQuery))
// 调用搜索API
searchResults, err := h.performSearch(searchQuery)
if err != nil {
common.SysError(fmt.Sprintf("Web search failed: %v", err))
return nil // 不中断请求,只记录错误
}
// 将搜索结果注入到请求中
h.injectSearchResults(requestData, searchResults)
// 更新请求体
modifiedBody, err := json.Marshal(requestData)
if err != nil {
return err
}
ctx.RequestBody = modifiedBody
// 存储到Data中供后续使用
ctx.Data["web_search_performed"] = true
ctx.Data["web_search_query"] = searchQuery
return nil
}
// OnAfterResponse 响应后处理
func (h *WebSearchHook) OnAfterResponse(ctx *interfaces.HookContext) error {
// 可以在这里记录搜索使用情况等
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
query := ctx.Data["web_search_query"].(string)
common.SysLog(fmt.Sprintf("Web search completed for query: %s", query))
}
return nil
}
// OnError 错误处理
func (h *WebSearchHook) OnError(ctx *interfaces.HookContext, err error) error {
// 记录错误但不影响主流程
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
common.SysError(fmt.Sprintf("Request failed after web search: %v", err))
}
return nil
}
// shouldSearch 判断是否需要搜索
func (h *WebSearchHook) shouldSearch(requestData map[string]interface{}) bool {
messages, ok := requestData["messages"].([]interface{})
if !ok || len(messages) == 0 {
return false
}
// 检查最后一条消息
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
if !ok {
return false
}
content, ok := lastMessage["content"].(string)
if !ok {
return false
}
// 简单示例:检查是否包含 [search] 或 [联网] 标记
return strings.Contains(content, "[search]") ||
strings.Contains(content, "[联网]") ||
strings.Contains(content, "[web]")
}
// extractSearchQuery 提取搜索查询
func (h *WebSearchHook) extractSearchQuery(requestData map[string]interface{}) string {
messages, ok := requestData["messages"].([]interface{})
if !ok || len(messages) == 0 {
return ""
}
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
if !ok {
return ""
}
content, ok := lastMessage["content"].(string)
if !ok {
return ""
}
// 移除标记,保留实际查询内容
query := strings.ReplaceAll(content, "[search]", "")
query = strings.ReplaceAll(query, "[联网]", "")
query = strings.ReplaceAll(query, "[web]", "")
query = strings.TrimSpace(query)
return query
}
// performSearch 执行搜索
func (h *WebSearchHook) performSearch(query string) (string, error) {
// 这里是示例实现实际应该调用真实的搜索API
// 例如Google Custom Search API, Bing Search API等
if h.apiKey == "" {
return "", fmt.Errorf("search API key not configured")
}
// 示例:返回模拟结果
// 实际实现需要调用真实API
return h.mockSearch(query)
}
// mockSearch 模拟搜索(示例)
func (h *WebSearchHook) mockSearch(query string) (string, error) {
// 这只是一个示例实现
// 实际应该调用真实的搜索API
common.SysLog(fmt.Sprintf("[Mock] Searching for: %s", query))
// 返回模拟的搜索结果
return fmt.Sprintf("搜索结果 (模拟):关于 '%s' 的最新信息...", query), nil
}
// realSearch 真实搜索实现示例需要配置API
func (h *WebSearchHook) realSearch(query string) (string, error) {
// 示例使用Google Custom Search API
url := fmt.Sprintf("https://www.googleapis.com/customsearch/v1?key=%s&cx=YOUR_CX&q=%s",
h.apiKey, query)
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 解析搜索结果
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return "", err
}
// 提取搜索结果摘要
// 这里需要根据实际API响应格式处理
return string(body), nil
}
// injectSearchResults 将搜索结果注入到请求中
func (h *WebSearchHook) injectSearchResults(requestData map[string]interface{}, results string) {
messages, ok := requestData["messages"].([]interface{})
if !ok {
return
}
// 在用户消息前插入系统消息,包含搜索结果
systemMessage := map[string]interface{}{
"role": "system",
"content": fmt.Sprintf("以下是针对用户查询的最新搜索结果:\n\n%s\n\n请基于这些信息回答用户的问题。", results),
}
// 插入到消息列表的适当位置
updatedMessages := make([]interface{}, 0, len(messages)+1)
// 如果第一条是系统消息,在其后插入
if len(messages) > 0 {
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
if role, ok := firstMsg["role"].(string); ok && role == "system" {
updatedMessages = append(updatedMessages, messages[0])
updatedMessages = append(updatedMessages, systemMessage)
updatedMessages = append(updatedMessages, messages[1:]...)
requestData["messages"] = updatedMessages
return
}
}
}
// 否则插入到开头
updatedMessages = append(updatedMessages, systemMessage)
updatedMessages = append(updatedMessages, messages...)
requestData["messages"] = updatedMessages
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/QuantumNous/new-api/common"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
@@ -55,6 +56,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var cacheCreationRatio float64
var audioRatio float64
var audioCompletionRatio float64
var freeModel bool
if !usePrice {
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
if meta.MaxTokens != 0 {
@@ -87,18 +89,35 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
// check if free model pre-consume is disabled
if !operation_setting.GetQuotaSetting().EnableFreeModelPreConsume {
// if model price or ratio is 0, do not pre-consume quota
if usePrice {
if modelPrice == 0 {
preConsumedQuota = 0
freeModel = true
}
} else {
if modelRatio == 0 {
preConsumedQuota = 0
freeModel = true
}
}
}
priceData := types.PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
AudioRatio: audioRatio,
AudioCompletionRatio: audioCompletionRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
FreeModel: freeModel,
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
AudioRatio: audioRatio,
AudioCompletionRatio: audioCompletionRatio,
CacheCreationRatio: cacheCreationRatio,
QuotaToPreConsume: preConsumedQuota,
}
if common.DebugEnabled {

136
relay/hooks/chain.go Normal file
View File

@@ -0,0 +1,136 @@
package hooks
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/core/interfaces"
"github.com/QuantumNous/new-api/core/registry"
)
var (
// 全局Hook链实例单例
globalChain *HookChain
globalChainOnce sync.Once
)
// HookChain Hook执行链
type HookChain struct {
hooks []interfaces.RelayHook
mu sync.RWMutex
}
// GetGlobalChain 获取全局Hook链实例
func GetGlobalChain() *HookChain {
globalChainOnce.Do(func() {
globalChain = &HookChain{
hooks: make([]interfaces.RelayHook, 0),
}
// 从注册中心加载hooks
globalChain.LoadHooks()
})
return globalChain
}
// LoadHooks 从注册中心加载hooks
func (c *HookChain) LoadHooks() {
c.mu.Lock()
defer c.mu.Unlock()
c.hooks = registry.ListHooks()
common.SysLog(fmt.Sprintf("Loaded %d enabled hooks", len(c.hooks)))
}
// ReloadHooks 重新加载hooks
func (c *HookChain) ReloadHooks() {
c.LoadHooks()
common.SysLog("Hooks reloaded")
}
// ExecuteBeforeRequest 执行所有BeforeRequest钩子
func (c *HookChain) ExecuteBeforeRequest(ctx *interfaces.HookContext) error {
c.mu.RLock()
hooks := c.hooks
c.mu.RUnlock()
for _, hook := range hooks {
if !hook.Enabled() {
continue
}
if ctx.ShouldSkip {
break
}
if err := hook.OnBeforeRequest(ctx); err != nil {
common.SysError(fmt.Sprintf("Hook %s OnBeforeRequest error: %v", hook.Name(), err))
return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
}
}
return nil
}
// ExecuteAfterResponse 执行所有AfterResponse钩子
func (c *HookChain) ExecuteAfterResponse(ctx *interfaces.HookContext) error {
c.mu.RLock()
hooks := c.hooks
c.mu.RUnlock()
for _, hook := range hooks {
if !hook.Enabled() {
continue
}
if ctx.ShouldSkip {
break
}
if err := hook.OnAfterResponse(ctx); err != nil {
common.SysError(fmt.Sprintf("Hook %s OnAfterResponse error: %v", hook.Name(), err))
return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
}
}
return nil
}
// ExecuteOnError 执行所有OnError钩子
func (c *HookChain) ExecuteOnError(ctx *interfaces.HookContext, err error) error {
c.mu.RLock()
hooks := c.hooks
c.mu.RUnlock()
for _, hook := range hooks {
if !hook.Enabled() {
continue
}
if hookErr := hook.OnError(ctx, err); hookErr != nil {
common.SysError(fmt.Sprintf("Hook %s OnError failed: %v", hook.Name(), hookErr))
// OnError钩子的错误不会中断执行
}
}
return err
}
// GetHooks 获取当前hook列表
func (c *HookChain) GetHooks() []interfaces.RelayHook {
c.mu.RLock()
defer c.mu.RUnlock()
hooks := make([]interfaces.RelayHook, len(c.hooks))
copy(hooks, c.hooks)
return hooks
}
// Count 返回hook数量
func (c *HookChain) Count() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.hooks)
}

212
relay/hooks/chain_test.go Normal file
View File

@@ -0,0 +1,212 @@
package hooks
import (
"errors"
"testing"
"github.com/QuantumNous/new-api/core/interfaces"
"github.com/QuantumNous/new-api/core/registry"
)
// Mock Hook实现
type testHook struct {
name string
priority int
enabled bool
beforeCalled bool
afterCalled bool
errorCalled bool
shouldReturnError bool
}
func (h *testHook) Name() string { return h.name }
func (h *testHook) Priority() int { return h.priority }
func (h *testHook) Enabled() bool { return h.enabled }
func (h *testHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
h.beforeCalled = true
if h.shouldReturnError {
return errors.New("test error")
}
return nil
}
func (h *testHook) OnAfterResponse(ctx *interfaces.HookContext) error {
h.afterCalled = true
if h.shouldReturnError {
return errors.New("test error")
}
return nil
}
func (h *testHook) OnError(ctx *interfaces.HookContext, err error) error {
h.errorCalled = true
return nil
}
func TestHookChainExecution(t *testing.T) {
// 创建测试hooks
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
hook3 := &testHook{name: "hook3", priority: 75, enabled: false} // disabled
// 创建Hook链
chain := &HookChain{
hooks: []interfaces.RelayHook{hook1, hook2, hook3},
}
// 创建测试上下文
ctx := &interfaces.HookContext{
Data: make(map[string]interface{}),
}
// 测试ExecuteBeforeRequest
if err := chain.ExecuteBeforeRequest(ctx); err != nil {
t.Errorf("ExecuteBeforeRequest failed: %v", err)
}
// 检查enabled的hooks是否被调用
if !hook1.beforeCalled {
t.Error("hook1 OnBeforeRequest should be called")
}
if !hook2.beforeCalled {
t.Error("hook2 OnBeforeRequest should be called")
}
// disabled的hook不应该被调用
if hook3.beforeCalled {
t.Error("hook3 OnBeforeRequest should not be called (disabled)")
}
// 测试ExecuteAfterResponse
if err := chain.ExecuteAfterResponse(ctx); err != nil {
t.Errorf("ExecuteAfterResponse failed: %v", err)
}
if !hook1.afterCalled {
t.Error("hook1 OnAfterResponse should be called")
}
if !hook2.afterCalled {
t.Error("hook2 OnAfterResponse should be called")
}
// 测试ExecuteOnError
testErr := errors.New("test error")
if err := chain.ExecuteOnError(ctx, testErr); err != testErr {
t.Error("ExecuteOnError should return original error")
}
if !hook1.errorCalled {
t.Error("hook1 OnError should be called")
}
}
func TestHookChainErrorHandling(t *testing.T) {
// 创建会返回错误的hook
errorHook := &testHook{
name: "error_hook",
priority: 100,
enabled: true,
shouldReturnError: true,
}
chain := &HookChain{
hooks: []interfaces.RelayHook{errorHook},
}
ctx := &interfaces.HookContext{
Data: make(map[string]interface{}),
}
// 测试错误处理
if err := chain.ExecuteBeforeRequest(ctx); err == nil {
t.Error("Expected error from ExecuteBeforeRequest")
}
}
func TestHookChainShouldSkip(t *testing.T) {
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
chain := &HookChain{
hooks: []interfaces.RelayHook{hook1, hook2},
}
ctx := &interfaces.HookContext{
Data: make(map[string]interface{}),
ShouldSkip: true, // 设置跳过标记
}
// 执行
if err := chain.ExecuteBeforeRequest(ctx); err != nil {
t.Errorf("ExecuteBeforeRequest failed: %v", err)
}
// 由于ShouldSkip为truehooks不应该被调用
// 注意当前实现在第一个hook执行后才会检查ShouldSkip
// 所以hook1会被调用但hook2不会
}
func TestHookChainCount(t *testing.T) {
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
chain := &HookChain{
hooks: []interfaces.RelayHook{hook1, hook2},
}
if count := chain.Count(); count != 2 {
t.Errorf("Expected count 2, got %d", count)
}
}
func TestHookChainGetHooks(t *testing.T) {
hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
chain := &HookChain{
hooks: []interfaces.RelayHook{hook1, hook2},
}
hooks := chain.GetHooks()
if len(hooks) != 2 {
t.Errorf("Expected 2 hooks, got %d", len(hooks))
}
}
func TestGlobalChain(t *testing.T) {
// 测试全局链的单例模式
chain1 := GetGlobalChain()
chain2 := GetGlobalChain()
if chain1 != chain2 {
t.Error("GetGlobalChain should return the same instance")
}
}
// 集成测试:测试完整的注册和执行流程
func TestIntegration(t *testing.T) {
// 注册测试hook
testHook := &testHook{
name: "integration_test_hook",
priority: 100,
enabled: true,
}
if err := registry.RegisterHook(testHook); err != nil {
// 如果已注册,跳过错误
t.Logf("Hook already registered (expected in some cases): %v", err)
}
// 创建新的hook链并加载
chain := &HookChain{hooks: make([]interfaces.RelayHook, 0)}
chain.LoadHooks()
// 检查是否加载了hooks
if chain.Count() == 0 {
t.Log("No hooks loaded (expected if registry is clean)")
}
}

View File

@@ -0,0 +1,79 @@
package hooks
import (
"net/http"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/core/interfaces"
"github.com/gin-gonic/gin"
)
// BuildHookContext 从Gin Context构建HookContext
func BuildHookContext(c *gin.Context) *interfaces.HookContext {
ctx := &interfaces.HookContext{
GinContext: c,
Request: c.Request,
Data: make(map[string]interface{}),
}
// 提取Channel信息
if channelID, ok := common.GetContextKey(c, constant.ContextKeyChannelId); ok {
if id, ok := channelID.(int); ok {
ctx.ChannelID = id
}
}
if channelType, ok := common.GetContextKey(c, constant.ContextKeyChannelType); ok {
if t, ok := channelType.(int); ok {
ctx.ChannelType = t
}
}
if channelName, ok := common.GetContextKey(c, constant.ContextKeyChannelName); ok {
if name, ok := channelName.(string); ok {
ctx.ChannelName = name
}
}
// 提取Model信息
if originalModel, ok := common.GetContextKey(c, constant.ContextKeyOriginalModel); ok {
if m, ok := originalModel.(string); ok {
ctx.OriginalModel = m
ctx.Model = m // 使用OriginalModel作为Model
}
}
// 提取User信息
if userID, ok := common.GetContextKey(c, constant.ContextKeyUserId); ok {
if id, ok := userID.(int); ok {
ctx.UserID = id
}
}
if tokenID, ok := common.GetContextKey(c, constant.ContextKeyTokenId); ok {
if id, ok := tokenID.(int); ok {
ctx.TokenID = id
}
}
if group, ok := common.GetContextKey(c, constant.ContextKeyUsingGroup); ok {
if g, ok := group.(string); ok {
ctx.Group = g
}
}
return ctx
}
// UpdateHookContextWithResponse 更新HookContext的Response信息
func UpdateHookContextWithResponse(ctx *interfaces.HookContext, resp *http.Response, body []byte) {
ctx.Response = resp
ctx.ResponseBody = body
}
// UpdateHookContextWithRequest 更新HookContext的Request信息
func UpdateHookContextWithRequest(ctx *interfaces.HookContext, body []byte) {
ctx.RequestBody = body
}

View File

@@ -4,6 +4,8 @@ import (
"strconv"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/core/registry"
pluginchannels "github.com/QuantumNous/new-api/plugins/channels"
"github.com/QuantumNous/new-api/relay/channel"
"github.com/QuantumNous/new-api/relay/channel/ali"
"github.com/QuantumNous/new-api/relay/channel/aws"
@@ -44,7 +46,19 @@ import (
"github.com/gin-gonic/gin"
)
// GetAdaptor 获取Channel适配器优先从插件注册中心获取保持向后兼容
func GetAdaptor(apiType int) channel.Adaptor {
// 优先从插件注册中心获取
if plugin, err := registry.GetChannel(apiType); err == nil {
// 如果是BaseChannelPlugin提取内部的Adaptor
if basePlugin, ok := plugin.(*pluginchannels.BaseChannelPlugin); ok {
return basePlugin.GetAdaptor()
}
// 否则直接返回plugin它也实现了Adaptor接口
return plugin
}
// 向后兼容:如果注册中心没有,使用原有的硬编码方式
switch apiType {
case constant.APITypeAli:
return &ali.Adaptor{}

View File

@@ -0,0 +1,21 @@
package operation_setting
import "github.com/QuantumNous/new-api/setting/config"
type QuotaSetting struct {
EnableFreeModelPreConsume bool `json:"enable_free_model_pre_consume"` // 是否对免费模型启用预消耗
}
// 默认配置
var quotaSetting = QuotaSetting{
EnableFreeModelPreConsume: true,
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("quota_setting", &quotaSetting)
}
func GetQuotaSetting() *QuotaSetting {
return &quotaSetting
}

View File

@@ -9,18 +9,19 @@ type GroupRatioInfo struct {
}
type PriceData struct {
ModelPrice float64
ModelRatio float64
CompletionRatio float64
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
AudioRatio float64
AudioCompletionRatio float64
OtherRatios map[string]float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
FreeModel bool
ModelPrice float64
ModelRatio float64
CompletionRatio float64
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
AudioRatio float64
AudioCompletionRatio float64
OtherRatios map[string]float64
UsePrice bool
QuotaToPreConsume int // 预消耗额度
GroupRatioInfo GroupRatioInfo
}
type PerCallPriceData struct {
@@ -30,5 +31,5 @@ type PerCallPriceData struct {
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
}

View File

@@ -35,6 +35,7 @@ const OperationSetting = () => {
PreConsumedQuota: 0,
QuotaForInviter: 0,
QuotaForInvitee: 0,
'quota_setting.enable_free_model_pre_consume': true,
/* 通用设置 */
TopUpLink: '',

View File

@@ -36,6 +36,7 @@ export default function SettingsCreditLimit(props) {
PreConsumedQuota: '',
QuotaForInviter: '',
QuotaForInvitee: '',
'quota_setting.enable_free_model_pre_consume': true,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -166,6 +167,21 @@ export default function SettingsCreditLimit(props) {
/>
</Col>
</Row>
<Row>
<Col>
<Form.Switch
label={t('对免费模型启用预消耗')}
field={'quota_setting.enable_free_model_pre_consume'}
extraText={t('开启后对免费模型倍率为0或者价格为0的模型也会预消耗额度')}
onChange={(value) =>
setInputs({
...inputs,
'quota_setting.enable_free_model_pre_consume': value,
})
}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>