mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-30 13:11:45 +00:00
✨ feat(architecture): Core+Plugin
This commit is contained in:
114
plugins/channels/base_plugin.go
Normal file
114
plugins/channels/base_plugin.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BaseChannelPlugin 基础Channel插件
|
||||
// 包装现有的Adaptor实现,使其符合ChannelPlugin接口
|
||||
type BaseChannelPlugin struct {
|
||||
adaptor channel.Adaptor
|
||||
name string
|
||||
version string
|
||||
priority int
|
||||
}
|
||||
|
||||
// NewBaseChannelPlugin 创建基础Channel插件
|
||||
func NewBaseChannelPlugin(adaptor channel.Adaptor, name, version string, priority int) *BaseChannelPlugin {
|
||||
return &BaseChannelPlugin{
|
||||
adaptor: adaptor,
|
||||
name: name,
|
||||
version: version,
|
||||
priority: priority,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回插件名称
|
||||
func (p *BaseChannelPlugin) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// Version 返回插件版本
|
||||
func (p *BaseChannelPlugin) Version() string {
|
||||
return p.version
|
||||
}
|
||||
|
||||
// Priority 返回优先级
|
||||
func (p *BaseChannelPlugin) Priority() int {
|
||||
return p.priority
|
||||
}
|
||||
|
||||
// 以下方法直接委托给内部的Adaptor
|
||||
|
||||
func (p *BaseChannelPlugin) Init(info *relaycommon.RelayInfo) {
|
||||
p.adaptor.Init(info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return p.adaptor.GetRequestURL(info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
return p.adaptor.SetupRequestHeader(c, req, info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
return p.adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return p.adaptor.ConvertRerankRequest(c, relayMode, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return p.adaptor.ConvertEmbeddingRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return p.adaptor.ConvertAudioRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
return p.adaptor.ConvertImageRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
return p.adaptor.ConvertOpenAIResponsesRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return p.adaptor.DoRequest(c, info, requestBody)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
return p.adaptor.DoResponse(c, resp, info)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) GetModelList() []string {
|
||||
return p.adaptor.GetModelList()
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) GetChannelName() string {
|
||||
return p.adaptor.GetChannelName()
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
return p.adaptor.ConvertClaudeRequest(c, info, request)
|
||||
}
|
||||
|
||||
func (p *BaseChannelPlugin) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
|
||||
return p.adaptor.ConvertGeminiRequest(c, info, request)
|
||||
}
|
||||
|
||||
// GetAdaptor 获取内部的Adaptor(用于向后兼容)
|
||||
func (p *BaseChannelPlugin) GetAdaptor() channel.Adaptor {
|
||||
return p.adaptor
|
||||
}
|
||||
|
||||
106
plugins/channels/registry.go
Normal file
106
plugins/channels/registry.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ali"
|
||||
"github.com/QuantumNous/new-api/relay/channel/aws"
|
||||
"github.com/QuantumNous/new-api/relay/channel/baidu"
|
||||
"github.com/QuantumNous/new-api/relay/channel/baidu_v2"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
"github.com/QuantumNous/new-api/relay/channel/cloudflare"
|
||||
"github.com/QuantumNous/new-api/relay/channel/cohere"
|
||||
"github.com/QuantumNous/new-api/relay/channel/coze"
|
||||
"github.com/QuantumNous/new-api/relay/channel/deepseek"
|
||||
"github.com/QuantumNous/new-api/relay/channel/dify"
|
||||
"github.com/QuantumNous/new-api/relay/channel/gemini"
|
||||
"github.com/QuantumNous/new-api/relay/channel/jimeng"
|
||||
"github.com/QuantumNous/new-api/relay/channel/jina"
|
||||
"github.com/QuantumNous/new-api/relay/channel/mistral"
|
||||
"github.com/QuantumNous/new-api/relay/channel/mokaai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/moonshot"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ollama"
|
||||
"github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/palm"
|
||||
"github.com/QuantumNous/new-api/relay/channel/perplexity"
|
||||
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
|
||||
"github.com/QuantumNous/new-api/relay/channel/submodel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/tencent"
|
||||
"github.com/QuantumNous/new-api/relay/channel/vertex"
|
||||
"github.com/QuantumNous/new-api/relay/channel/volcengine"
|
||||
"github.com/QuantumNous/new-api/relay/channel/xai"
|
||||
"github.com/QuantumNous/new-api/relay/channel/xunfei"
|
||||
"github.com/QuantumNous/new-api/relay/channel/zhipu"
|
||||
"github.com/QuantumNous/new-api/relay/channel/zhipu_4v"
|
||||
)
|
||||
|
||||
// init 包初始化时自动注册所有Channel插件
|
||||
func init() {
|
||||
RegisterAllChannels()
|
||||
}
|
||||
|
||||
// RegisterAllChannels 注册所有Channel插件
|
||||
func RegisterAllChannels() {
|
||||
// 包装现有的Adaptor并注册为插件
|
||||
channels := []struct {
|
||||
channelType int
|
||||
adaptor channel.Adaptor
|
||||
name string
|
||||
}{
|
||||
{constant.APITypeOpenAI, &openai.Adaptor{}, "openai"},
|
||||
{constant.APITypeAnthropic, &claude.Adaptor{}, "claude"},
|
||||
{constant.APITypeGemini, &gemini.Adaptor{}, "gemini"},
|
||||
{constant.APITypeAli, &ali.Adaptor{}, "ali"},
|
||||
{constant.APITypeBaidu, &baidu.Adaptor{}, "baidu"},
|
||||
{constant.APITypeBaiduV2, &baidu_v2.Adaptor{}, "baidu_v2"},
|
||||
{constant.APITypeTencent, &tencent.Adaptor{}, "tencent"},
|
||||
{constant.APITypeXunfei, &xunfei.Adaptor{}, "xunfei"},
|
||||
{constant.APITypeZhipu, &zhipu.Adaptor{}, "zhipu"},
|
||||
{constant.APITypeZhipuV4, &zhipu_4v.Adaptor{}, "zhipu_v4"},
|
||||
{constant.APITypeOllama, &ollama.Adaptor{}, "ollama"},
|
||||
{constant.APITypePerplexity, &perplexity.Adaptor{}, "perplexity"},
|
||||
{constant.APITypeAws, &aws.Adaptor{}, "aws"},
|
||||
{constant.APITypeCohere, &cohere.Adaptor{}, "cohere"},
|
||||
{constant.APITypeDify, &dify.Adaptor{}, "dify"},
|
||||
{constant.APITypeJina, &jina.Adaptor{}, "jina"},
|
||||
{constant.APITypeCloudflare, &cloudflare.Adaptor{}, "cloudflare"},
|
||||
{constant.APITypeSiliconFlow, &siliconflow.Adaptor{}, "siliconflow"},
|
||||
{constant.APITypeVertexAi, &vertex.Adaptor{}, "vertex"},
|
||||
{constant.APITypeMistral, &mistral.Adaptor{}, "mistral"},
|
||||
{constant.APITypeDeepSeek, &deepseek.Adaptor{}, "deepseek"},
|
||||
{constant.APITypeMokaAI, &mokaai.Adaptor{}, "mokaai"},
|
||||
{constant.APITypeVolcEngine, &volcengine.Adaptor{}, "volcengine"},
|
||||
{constant.APITypeXai, &xai.Adaptor{}, "xai"},
|
||||
{constant.APITypeCoze, &coze.Adaptor{}, "coze"},
|
||||
{constant.APITypeJimeng, &jimeng.Adaptor{}, "jimeng"},
|
||||
{constant.APITypeMoonshot, &moonshot.Adaptor{}, "moonshot"},
|
||||
{constant.APITypeSubmodel, &submodel.Adaptor{}, "submodel"},
|
||||
{constant.APITypePaLM, &palm.Adaptor{}, "palm"},
|
||||
// OpenRouter 和 Xinference 使用 OpenAI adaptor
|
||||
{constant.APITypeOpenRouter, &openai.Adaptor{}, "openrouter"},
|
||||
{constant.APITypeXinference, &openai.Adaptor{}, "xinference"},
|
||||
}
|
||||
|
||||
registeredCount := 0
|
||||
for _, ch := range channels {
|
||||
plugin := NewBaseChannelPlugin(
|
||||
ch.adaptor,
|
||||
ch.name,
|
||||
"1.0.0",
|
||||
100, // 默认优先级
|
||||
)
|
||||
|
||||
if err := registry.RegisterChannel(ch.channelType, plugin); err != nil {
|
||||
common.SysError("Failed to register channel plugin: " + ch.name + ", error: " + err.Error())
|
||||
} else {
|
||||
registeredCount++
|
||||
}
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Registered %d channel plugins", registeredCount))
|
||||
}
|
||||
|
||||
186
plugins/hooks/content_filter/content_filter_hook.go
Normal file
186
plugins/hooks/content_filter/content_filter_hook.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package content_filter
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
// ContentFilterHook 内容过滤Hook
|
||||
// 在响应返回前过滤敏感内容
|
||||
type ContentFilterHook struct {
|
||||
enabled bool
|
||||
priority int
|
||||
sensitiveWords []string
|
||||
filterNSFW bool
|
||||
filterPolitical bool
|
||||
replacementText string
|
||||
}
|
||||
|
||||
// NewContentFilterHook 创建ContentFilterHook实例
|
||||
func NewContentFilterHook(config map[string]interface{}) *ContentFilterHook {
|
||||
hook := &ContentFilterHook{
|
||||
enabled: true,
|
||||
priority: 100, // 高优先级,最后执行
|
||||
sensitiveWords: []string{},
|
||||
filterNSFW: true,
|
||||
filterPolitical: false,
|
||||
replacementText: "[已过滤]",
|
||||
}
|
||||
|
||||
if enabled, ok := config["enabled"].(bool); ok {
|
||||
hook.enabled = enabled
|
||||
}
|
||||
|
||||
if priority, ok := config["priority"].(int); ok {
|
||||
hook.priority = priority
|
||||
}
|
||||
|
||||
if filterNSFW, ok := config["filter_nsfw"].(bool); ok {
|
||||
hook.filterNSFW = filterNSFW
|
||||
}
|
||||
|
||||
if filterPolitical, ok := config["filter_political"].(bool); ok {
|
||||
hook.filterPolitical = filterPolitical
|
||||
}
|
||||
|
||||
if words, ok := config["sensitive_words"].([]interface{}); ok {
|
||||
for _, word := range words {
|
||||
if w, ok := word.(string); ok {
|
||||
hook.sensitiveWords = append(hook.sensitiveWords, w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return hook
|
||||
}
|
||||
|
||||
// Name 返回Hook名称
|
||||
func (h *ContentFilterHook) Name() string {
|
||||
return "content_filter"
|
||||
}
|
||||
|
||||
// Priority 返回优先级
|
||||
func (h *ContentFilterHook) Priority() int {
|
||||
return h.priority
|
||||
}
|
||||
|
||||
// Enabled 返回是否启用
|
||||
func (h *ContentFilterHook) Enabled() bool {
|
||||
return h.enabled
|
||||
}
|
||||
|
||||
// OnBeforeRequest 请求前处理(不需要处理)
|
||||
func (h *ContentFilterHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnAfterResponse 响应后处理 - 过滤内容
|
||||
func (h *ContentFilterHook) OnAfterResponse(ctx *interfaces.HookContext) error {
|
||||
if !h.Enabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理chat completion响应
|
||||
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果没有响应体,跳过
|
||||
if len(ctx.ResponseBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.ResponseBody, &response); err != nil {
|
||||
return nil // 忽略解析错误
|
||||
}
|
||||
|
||||
// 过滤内容
|
||||
filtered := h.filterResponse(response)
|
||||
|
||||
// 如果内容被修改,更新响应体
|
||||
if filtered {
|
||||
modifiedBody, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.ResponseBody = modifiedBody
|
||||
|
||||
// 记录过滤事件
|
||||
ctx.Data["content_filtered"] = true
|
||||
common.SysLog("Content filter applied to response")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnError 错误处理
|
||||
func (h *ContentFilterHook) OnError(ctx *interfaces.HookContext, err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterResponse 过滤响应内容
|
||||
func (h *ContentFilterHook) filterResponse(response map[string]interface{}) bool {
|
||||
modified := false
|
||||
|
||||
// 获取choices数组
|
||||
choices, ok := response["choices"].([]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 遍历每个choice
|
||||
for _, choice := range choices {
|
||||
choiceMap, ok := choice.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取message
|
||||
message, ok := choiceMap["message"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取content
|
||||
content, ok := message["content"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 过滤内容
|
||||
filteredContent := h.filterText(content)
|
||||
|
||||
// 如果内容被修改
|
||||
if filteredContent != content {
|
||||
message["content"] = filteredContent
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
// filterText 过滤文本内容
|
||||
func (h *ContentFilterHook) filterText(text string) string {
|
||||
filtered := text
|
||||
|
||||
// 过滤敏感词
|
||||
for _, word := range h.sensitiveWords {
|
||||
if strings.Contains(filtered, word) {
|
||||
filtered = strings.ReplaceAll(filtered, word, h.replacementText)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 实现更复杂的过滤逻辑
|
||||
// - NSFW内容检测
|
||||
// - 政治敏感内容检测
|
||||
// - 使用AI模型进行内容分类
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
39
plugins/hooks/content_filter/init.go
Normal file
39
plugins/hooks/content_filter/init.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package content_filter
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取配置
|
||||
config := map[string]interface{}{
|
||||
"enabled": os.Getenv("CONTENT_FILTER_ENABLED") == "true",
|
||||
"priority": 100,
|
||||
"filter_nsfw": os.Getenv("CONTENT_FILTER_NSFW") != "false",
|
||||
"filter_political": os.Getenv("CONTENT_FILTER_POLITICAL") == "true",
|
||||
}
|
||||
|
||||
// 读取敏感词列表
|
||||
if wordsEnv := os.Getenv("CONTENT_FILTER_WORDS"); wordsEnv != "" {
|
||||
words := strings.Split(wordsEnv, ",")
|
||||
config["sensitive_words"] = words
|
||||
}
|
||||
|
||||
// 创建并注册Hook
|
||||
hook := NewContentFilterHook(config)
|
||||
|
||||
if err := registry.RegisterHook(hook); err != nil {
|
||||
common.SysError("Failed to register content_filter hook: " + err.Error())
|
||||
} else {
|
||||
if hook.Enabled() {
|
||||
common.SysLog("Content filter hook registered and enabled")
|
||||
} else {
|
||||
common.SysLog("Content filter hook registered but disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
39
plugins/hooks/web_search/init.go
Normal file
39
plugins/hooks/web_search/init.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取配置
|
||||
config := map[string]interface{}{
|
||||
"enabled": os.Getenv("WEB_SEARCH_ENABLED") == "true",
|
||||
"api_key": os.Getenv("WEB_SEARCH_API_KEY"),
|
||||
"provider": getEnvOrDefault("WEB_SEARCH_PROVIDER", "google"),
|
||||
"priority": 50,
|
||||
}
|
||||
|
||||
// 创建并注册Hook
|
||||
hook := NewWebSearchHook(config)
|
||||
|
||||
if err := registry.RegisterHook(hook); err != nil {
|
||||
common.SysError("Failed to register web_search hook: " + err.Error())
|
||||
} else {
|
||||
if hook.Enabled() {
|
||||
common.SysLog("Web search hook registered and enabled")
|
||||
} else {
|
||||
common.SysLog("Web search hook registered but disabled (missing API key or not enabled)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
281
plugins/hooks/web_search/web_search_hook.go
Normal file
281
plugins/hooks/web_search/web_search_hook.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/core/interfaces"
|
||||
)
|
||||
|
||||
// WebSearchHook 联网搜索Hook插件
|
||||
// 在请求发送前检测是否需要联网搜索,如果需要则调用搜索API并将结果注入到请求中
|
||||
type WebSearchHook struct {
|
||||
enabled bool
|
||||
priority int
|
||||
apiKey string
|
||||
provider string // google, bing, etc
|
||||
}
|
||||
|
||||
// NewWebSearchHook 创建WebSearchHook实例
|
||||
func NewWebSearchHook(config map[string]interface{}) *WebSearchHook {
|
||||
hook := &WebSearchHook{
|
||||
enabled: true,
|
||||
priority: 50, // 中等优先级
|
||||
provider: "google",
|
||||
}
|
||||
|
||||
if apiKey, ok := config["api_key"].(string); ok {
|
||||
hook.apiKey = apiKey
|
||||
}
|
||||
|
||||
if provider, ok := config["provider"].(string); ok {
|
||||
hook.provider = provider
|
||||
}
|
||||
|
||||
if priority, ok := config["priority"].(int); ok {
|
||||
hook.priority = priority
|
||||
}
|
||||
|
||||
if enabled, ok := config["enabled"].(bool); ok {
|
||||
hook.enabled = enabled
|
||||
}
|
||||
|
||||
return hook
|
||||
}
|
||||
|
||||
// Name 返回Hook名称
|
||||
func (h *WebSearchHook) Name() string {
|
||||
return "web_search"
|
||||
}
|
||||
|
||||
// Priority 返回优先级
|
||||
func (h *WebSearchHook) Priority() int {
|
||||
return h.priority
|
||||
}
|
||||
|
||||
// Enabled 返回是否启用
|
||||
func (h *WebSearchHook) Enabled() bool {
|
||||
return h.enabled && h.apiKey != ""
|
||||
}
|
||||
|
||||
// OnBeforeRequest 请求前处理
|
||||
func (h *WebSearchHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
|
||||
if !h.Enabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理chat completion请求
|
||||
if !strings.Contains(ctx.Request.URL.Path, "chat/completions") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查请求体中是否包含搜索关键词
|
||||
if len(ctx.RequestBody) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var requestData map[string]interface{}
|
||||
if err := json.Unmarshal(ctx.RequestBody, &requestData); err != nil {
|
||||
return nil // 忽略解析错误
|
||||
}
|
||||
|
||||
// 检查是否需要搜索(简单示例:检查最后一条消息是否包含 [search] 标记)
|
||||
if !h.shouldSearch(requestData) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行搜索
|
||||
searchQuery := h.extractSearchQuery(requestData)
|
||||
if searchQuery == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Web search triggered for query: %s", searchQuery))
|
||||
|
||||
// 调用搜索API
|
||||
searchResults, err := h.performSearch(searchQuery)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Web search failed: %v", err))
|
||||
return nil // 不中断请求,只记录错误
|
||||
}
|
||||
|
||||
// 将搜索结果注入到请求中
|
||||
h.injectSearchResults(requestData, searchResults)
|
||||
|
||||
// 更新请求体
|
||||
modifiedBody, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.RequestBody = modifiedBody
|
||||
|
||||
// 存储到Data中供后续使用
|
||||
ctx.Data["web_search_performed"] = true
|
||||
ctx.Data["web_search_query"] = searchQuery
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnAfterResponse 响应后处理
|
||||
func (h *WebSearchHook) OnAfterResponse(ctx *interfaces.HookContext) error {
|
||||
// 可以在这里记录搜索使用情况等
|
||||
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
|
||||
query := ctx.Data["web_search_query"].(string)
|
||||
common.SysLog(fmt.Sprintf("Web search completed for query: %s", query))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnError 错误处理
|
||||
func (h *WebSearchHook) OnError(ctx *interfaces.HookContext, err error) error {
|
||||
// 记录错误但不影响主流程
|
||||
if performed, ok := ctx.Data["web_search_performed"].(bool); ok && performed {
|
||||
common.SysError(fmt.Sprintf("Request failed after web search: %v", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldSearch 判断是否需要搜索
|
||||
func (h *WebSearchHook) shouldSearch(requestData map[string]interface{}) bool {
|
||||
messages, ok := requestData["messages"].([]interface{})
|
||||
if !ok || len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查最后一条消息
|
||||
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
content, ok := lastMessage["content"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 简单示例:检查是否包含 [search] 或 [联网] 标记
|
||||
return strings.Contains(content, "[search]") ||
|
||||
strings.Contains(content, "[联网]") ||
|
||||
strings.Contains(content, "[web]")
|
||||
}
|
||||
|
||||
// extractSearchQuery 提取搜索查询
|
||||
func (h *WebSearchHook) extractSearchQuery(requestData map[string]interface{}) string {
|
||||
messages, ok := requestData["messages"].([]interface{})
|
||||
if !ok || len(messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
lastMessage, ok := messages[len(messages)-1].(map[string]interface{})
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
content, ok := lastMessage["content"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 移除标记,保留实际查询内容
|
||||
query := strings.ReplaceAll(content, "[search]", "")
|
||||
query = strings.ReplaceAll(query, "[联网]", "")
|
||||
query = strings.ReplaceAll(query, "[web]", "")
|
||||
query = strings.TrimSpace(query)
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// performSearch 执行搜索
|
||||
func (h *WebSearchHook) performSearch(query string) (string, error) {
|
||||
// 这里是示例实现,实际应该调用真实的搜索API
|
||||
// 例如:Google Custom Search API, Bing Search API等
|
||||
|
||||
if h.apiKey == "" {
|
||||
return "", fmt.Errorf("search API key not configured")
|
||||
}
|
||||
|
||||
// 示例:返回模拟结果
|
||||
// 实际实现需要调用真实API
|
||||
return h.mockSearch(query)
|
||||
}
|
||||
|
||||
// mockSearch 模拟搜索(示例)
|
||||
func (h *WebSearchHook) mockSearch(query string) (string, error) {
|
||||
// 这只是一个示例实现
|
||||
// 实际应该调用真实的搜索API
|
||||
|
||||
common.SysLog(fmt.Sprintf("[Mock] Searching for: %s", query))
|
||||
|
||||
// 返回模拟的搜索结果
|
||||
return fmt.Sprintf("搜索结果 (模拟):关于 '%s' 的最新信息...", query), nil
|
||||
}
|
||||
|
||||
// realSearch 真实搜索实现示例(需要配置API)
|
||||
func (h *WebSearchHook) realSearch(query string) (string, error) {
|
||||
// 示例:使用Google Custom Search API
|
||||
url := fmt.Sprintf("https://www.googleapis.com/customsearch/v1?key=%s&cx=YOUR_CX&q=%s",
|
||||
h.apiKey, query)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 解析搜索结果
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 提取搜索结果摘要
|
||||
// 这里需要根据实际API响应格式处理
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// injectSearchResults 将搜索结果注入到请求中
|
||||
func (h *WebSearchHook) injectSearchResults(requestData map[string]interface{}, results string) {
|
||||
messages, ok := requestData["messages"].([]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 在用户消息前插入系统消息,包含搜索结果
|
||||
systemMessage := map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": fmt.Sprintf("以下是针对用户查询的最新搜索结果:\n\n%s\n\n请基于这些信息回答用户的问题。", results),
|
||||
}
|
||||
|
||||
// 插入到消息列表的适当位置
|
||||
updatedMessages := make([]interface{}, 0, len(messages)+1)
|
||||
|
||||
// 如果第一条是系统消息,在其后插入
|
||||
if len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]interface{}); ok {
|
||||
if role, ok := firstMsg["role"].(string); ok && role == "system" {
|
||||
updatedMessages = append(updatedMessages, messages[0])
|
||||
updatedMessages = append(updatedMessages, systemMessage)
|
||||
updatedMessages = append(updatedMessages, messages[1:]...)
|
||||
requestData["messages"] = updatedMessages
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 否则插入到开头
|
||||
updatedMessages = append(updatedMessages, systemMessage)
|
||||
updatedMessages = append(updatedMessages, messages...)
|
||||
requestData["messages"] = updatedMessages
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user