diff --git a/common/copy.go b/common/copy.go index 3edb2fa25..87c63c4ec 100644 --- a/common/copy.go +++ b/common/copy.go @@ -6,14 +6,21 @@ import ( "github.com/jinzhu/copier" ) -func DeepCopy[T any](src *T) (*T, error) { +func Copy[T any](src *T, deepCopy bool) (*T, error) { if src == nil { return nil, fmt.Errorf("copy source cannot be nil") } var dst T - err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true}) - if err != nil { - return nil, err + if deepCopy { + err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true}) + if err != nil { + return nil, err + } + } else { + err := copier.Copy(&dst, src) + if err != nil { + return nil, err + } } return &dst, nil } diff --git a/dto/gemini.go b/dto/gemini.go index 5df67ba0b..5c8d2a897 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -2,11 +2,12 @@ package dto import ( "encoding/json" - "github.com/gin-gonic/gin" "one-api/common" "one-api/logger" "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type GeminiChatRequest struct { diff --git a/dto/openai_request.go b/dto/openai_request.go index cd05a63c9..7745a36da 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -265,7 +265,7 @@ type Message struct { Reasoning string `json:"reasoning,omitempty"` ToolCalls json.RawMessage `json:"tool_calls,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"` - parsedContent []MediaContent + parsedContent *[]MediaContent //parsedStringContent *string } @@ -441,7 +441,7 @@ func (m *Message) SetStringContent(content string) { func (m *Message) SetMediaContent(content []MediaContent) { m.Content = content - m.parsedContent = content + m.parsedContent = &content } func (m *Message) IsStringContent() bool { @@ -456,8 +456,8 @@ func (m *Message) ParseContent() []MediaContent { if m.Content == nil { return nil } - if len(m.parsedContent) > 0 { - return m.parsedContent + if m.parsedContent != nil && len(*m.parsedContent) > 0 { + return *m.parsedContent } var contentList []MediaContent @@ -468,7 +468,7 @@ func (m *Message) ParseContent() []MediaContent { Type: ContentTypeText, Text: content, }} - m.parsedContent = contentList + m.parsedContent = &contentList return contentList } @@ -580,7 +580,7 @@ func (m *Message) ParseContent() []MediaContent { } if len(contentList) > 0 { - m.parsedContent = contentList + m.parsedContent = &contentList } return contentList } @@ -766,27 +766,27 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { - Model string `json:"model"` - Input json.RawMessage `json:"input,omitempty"` - Include json.RawMessage `json:"include,omitempty"` - Instructions json.RawMessage `json:"instructions,omitempty"` - MaxOutputTokens uint `json:"max_output_tokens,omitempty"` - Metadata json.RawMessage `json:"metadata,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` - PreviousResponseID string `json:"previous_response_id,omitempty"` - Reasoning *Reasoning `json:"reasoning,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` - Store bool `json:"store,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Text json.RawMessage `json:"text,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map - TopP float64 `json:"top_p,omitempty"` - Truncation string `json:"truncation,omitempty"` - User string `json:"user,omitempty"` - MaxToolCalls uint `json:"max_tool_calls,omitempty"` - Prompt json.RawMessage `json:"prompt,omitempty"` + Model string `json:"model"` + Input *json.RawMessage `json:"input,omitempty"` + Include json.RawMessage `json:"include,omitempty"` + Instructions json.RawMessage `json:"instructions,omitempty"` + MaxOutputTokens uint `json:"max_output_tokens,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Store bool `json:"store,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools *json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map + TopP float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + User string `json:"user,omitempty"` + MaxToolCalls uint `json:"max_tool_calls,omitempty"` + Prompt json.RawMessage `json:"prompt,omitempty"` } func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { @@ -837,8 +837,8 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { texts = append(texts, string(r.Prompt)) } - if len(r.Tools) > 0 { - texts = append(texts, string(r.Tools)) + if r.Tools != nil && len(*r.Tools) > 0 { + texts = append(texts, string(*r.Tools)) } return &types.TokenCountMeta{ @@ -859,9 +859,9 @@ func (r *OpenAIResponsesRequest) SetModelName(modelName string) { } func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any { - var toolsMap []map[string]any - if len(r.Tools) > 0 { - _ = common.Unmarshal(r.Tools, &toolsMap) + var toolsMap = make([]map[string]any, 0) + if r.Tools != nil && len(*r.Tools) > 0 { + _ = common.Unmarshal(*r.Tools, &toolsMap) } return toolsMap } @@ -896,17 +896,17 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput { // inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) // return inputs // } - if common.GetJsonType(r.Input) == "string" { + if common.GetJsonType(*r.Input) == "string" { var str string - _ = common.Unmarshal(r.Input, &str) + _ = common.Unmarshal(*r.Input, &str) inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) return inputs } // Try array of parts - if common.GetJsonType(r.Input) == "array" { + if common.GetJsonType(*r.Input) == "array" { var array []any - _ = common.Unmarshal(r.Input, &array) + _ = common.Unmarshal(*r.Input, &array) for _, itemAny := range array { // Already parsed MediaInput if media, ok := itemAny.(MediaInput); ok { diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 711cc7a9b..2be28593a 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -22,7 +22,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(audioReq) + request, err := common.Copy(audioReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 59c052f62..a78861e34 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -27,7 +27,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(claudeReq) + request, err := common.Copy(claudeReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index da572c070..22d8080c2 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -313,7 +313,7 @@ func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } - if len(request.Tools) > 0 { + if request.Tools != nil && len(*request.Tools) > 0 { for _, tool := range request.GetToolsMap() { toolType := common.Interface2String(tool["type"]) info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 1f6c525b5..6fb313d3f 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -32,7 +32,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(textReq) + request, err := common.Copy(textReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 26dcf9719..d6471c297 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -23,7 +23,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(embeddingReq) + request, err := common.Copy(embeddingReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 460fd2f58..df504a910 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -58,7 +58,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(geminiReq) + request, err := common.Copy(geminiReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/image_handler.go b/relay/image_handler.go index 14a7103c3..fc48884cd 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -26,7 +26,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(imageReq) + request, err := common.Copy(imageReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index fa3c7bbb4..cdc83c55f 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -24,7 +24,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(rerankReq) + request, err := common.Copy(rerankReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index f5f624c92..61e43daf1 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -25,7 +25,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } - request, err := common.DeepCopy(responsesReq) + request, err := common.Copy(responsesReq, false) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) }