diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index d9e522514..0636ecd44 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -250,9 +250,6 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe } if message.Role == "assistant" && message.ToolCalls != nil { fmtMessage.ToolCalls = message.ToolCalls - if message.IsStringContent() && message.StringContent() == "" { - fmtMessage.SetNullContent() - } } if lastMessage.Role == message.Role && lastMessage.Role != "tool" { if lastMessage.IsStringContent() && message.IsStringContent() { @@ -261,7 +258,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe formatMessages = formatMessages[:len(formatMessages)-1] } } - if fmtMessage.Content == nil && !(message.Role == "assistant" && message.ToolCalls != nil) { + if fmtMessage.Content == nil { fmtMessage.SetStringContent("...") } formatMessages = append(formatMessages, fmtMessage) @@ -376,9 +373,9 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe if message.ToolCalls != nil { for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) - if err := common.UnmarshalJsonStr(toolCall.Function.Arguments, &inputObj); err != nil { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) - inputObj = map[string]any{} + continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ Type: "tool_use", @@ -451,17 +448,11 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo choice.Delta.Content = claudeResponse.Delta.Text switch claudeResponse.Delta.Type { case "input_json_delta": - arguments := "{}" - if claudeResponse.Delta.PartialJson != nil { - if partial := strings.TrimSpace(*claudeResponse.Delta.PartialJson); partial != "" { - arguments = partial - } - } tools = append(tools, dto.ToolCallResponse{ Type: "function", Index: common.GetPointer(fcIdx), Function: dto.FunctionResponse{ - Arguments: arguments, + Arguments: *claudeResponse.Delta.PartialJson, }, }) case "signature_delta": diff --git a/relay/channel/claude/relay_claude_test.go b/relay/channel/claude/relay_claude_test.go index 986788cf9..e34c861ac 100644 --- a/relay/channel/claude/relay_claude_test.go +++ b/relay/channel/claude/relay_claude_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/QuantumNous/new-api/dto" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) { @@ -28,15 +26,28 @@ func TestFormatClaudeResponseInfo_MessageStart(t *testing.T) { } ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo) - require.True(t, ok) - assert.Equal(t, 100, claudeInfo.Usage.PromptTokens) - assert.Equal(t, 30, claudeInfo.Usage.PromptTokensDetails.CachedTokens) - assert.Equal(t, 50, claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens) - assert.Equal(t, "msg_123", claudeInfo.ResponseId) - assert.Equal(t, "claude-3-5-sonnet", claudeInfo.Model) + if !ok { + t.Fatal("expected true") + } + if claudeInfo.Usage.PromptTokens != 100 { + t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens) + } + if claudeInfo.Usage.PromptTokensDetails.CachedTokens != 30 { + t.Errorf("CachedTokens = %d, want 30", claudeInfo.Usage.PromptTokensDetails.CachedTokens) + } + if claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens != 50 { + t.Errorf("CachedCreationTokens = %d, want 50", claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens) + } + if claudeInfo.ResponseId != "msg_123" { + t.Errorf("ResponseId = %s, want msg_123", claudeInfo.ResponseId) + } + if claudeInfo.Model != "claude-3-5-sonnet" { + t.Errorf("Model = %s, want claude-3-5-sonnet", claudeInfo.Model) + } } func TestFormatClaudeResponseInfo_MessageDelta_FullUsage(t *testing.T) { + // message_start 先积累 usage claudeInfo := &ClaudeResponseInfo{ Usage: &dto.Usage{ PromptTokens: 100, @@ -48,6 +59,7 @@ func TestFormatClaudeResponseInfo_MessageDelta_FullUsage(t *testing.T) { }, } + // message_delta 带完整 usage(原生 Anthropic 场景) claudeResponse := &dto.ClaudeResponse{ Type: "message_delta", Usage: &dto.ClaudeUsage{ @@ -59,14 +71,25 @@ func TestFormatClaudeResponseInfo_MessageDelta_FullUsage(t *testing.T) { } ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo) - require.True(t, ok) - assert.Equal(t, 100, claudeInfo.Usage.PromptTokens) - assert.Equal(t, 200, claudeInfo.Usage.CompletionTokens) - assert.Equal(t, 300, claudeInfo.Usage.TotalTokens) - assert.True(t, claudeInfo.Done) + if !ok { + t.Fatal("expected true") + } + if claudeInfo.Usage.PromptTokens != 100 { + t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens) + } + if claudeInfo.Usage.CompletionTokens != 200 { + t.Errorf("CompletionTokens = %d, want 200", claudeInfo.Usage.CompletionTokens) + } + if claudeInfo.Usage.TotalTokens != 300 { + t.Errorf("TotalTokens = %d, want 300", claudeInfo.Usage.TotalTokens) + } + if !claudeInfo.Done { + t.Error("expected Done = true") + } } func TestFormatClaudeResponseInfo_MessageDelta_OnlyOutputTokens(t *testing.T) { + // 模拟 Bedrock: message_start 已积累 usage claudeInfo := &ClaudeResponseInfo{ Usage: &dto.Usage{ PromptTokens: 100, @@ -80,29 +103,53 @@ func TestFormatClaudeResponseInfo_MessageDelta_OnlyOutputTokens(t *testing.T) { }, } + // Bedrock 的 message_delta 只有 output_tokens,缺少 input_tokens 和 cache 字段 claudeResponse := &dto.ClaudeResponse{ Type: "message_delta", Usage: &dto.ClaudeUsage{ OutputTokens: 200, + // InputTokens, CacheCreationInputTokens, CacheReadInputTokens 都是 0 }, } ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo) - require.True(t, ok) - assert.Equal(t, 100, claudeInfo.Usage.PromptTokens) - assert.Equal(t, 200, claudeInfo.Usage.CompletionTokens) - assert.Equal(t, 300, claudeInfo.Usage.TotalTokens) - assert.Equal(t, 30, claudeInfo.Usage.PromptTokensDetails.CachedTokens) - assert.Equal(t, 50, claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens) - assert.Equal(t, 10, claudeInfo.Usage.ClaudeCacheCreation5mTokens) - assert.Equal(t, 20, claudeInfo.Usage.ClaudeCacheCreation1hTokens) - assert.True(t, claudeInfo.Done) + if !ok { + t.Fatal("expected true") + } + // PromptTokens 应保持 message_start 的值(因为 message_delta 的 InputTokens=0,不更新) + if claudeInfo.Usage.PromptTokens != 100 { + t.Errorf("PromptTokens = %d, want 100", claudeInfo.Usage.PromptTokens) + } + if claudeInfo.Usage.CompletionTokens != 200 { + t.Errorf("CompletionTokens = %d, want 200", claudeInfo.Usage.CompletionTokens) + } + if claudeInfo.Usage.TotalTokens != 300 { + t.Errorf("TotalTokens = %d, want 300", claudeInfo.Usage.TotalTokens) + } + // cache 字段应保持 message_start 的值 + if claudeInfo.Usage.PromptTokensDetails.CachedTokens != 30 { + t.Errorf("CachedTokens = %d, want 30", claudeInfo.Usage.PromptTokensDetails.CachedTokens) + } + if claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens != 50 { + t.Errorf("CachedCreationTokens = %d, want 50", claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens) + } + if claudeInfo.Usage.ClaudeCacheCreation5mTokens != 10 { + t.Errorf("ClaudeCacheCreation5mTokens = %d, want 10", claudeInfo.Usage.ClaudeCacheCreation5mTokens) + } + if claudeInfo.Usage.ClaudeCacheCreation1hTokens != 20 { + t.Errorf("ClaudeCacheCreation1hTokens = %d, want 20", claudeInfo.Usage.ClaudeCacheCreation1hTokens) + } + if !claudeInfo.Done { + t.Error("expected Done = true") + } } func TestFormatClaudeResponseInfo_NilClaudeInfo(t *testing.T) { claudeResponse := &dto.ClaudeResponse{Type: "message_start"} ok := FormatClaudeResponseInfo(claudeResponse, nil, nil) - assert.False(t, ok) + if ok { + t.Error("expected false for nil claudeInfo") + } } func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) { @@ -119,137 +166,10 @@ func TestFormatClaudeResponseInfo_ContentBlockDelta(t *testing.T) { } ok := FormatClaudeResponseInfo(claudeResponse, nil, claudeInfo) - require.True(t, ok) - assert.Equal(t, "hello", claudeInfo.ResponseText.String()) -} - -func TestRequestOpenAI2ClaudeMessage_AssistantToolCallWithEmptyContent(t *testing.T) { - request := dto.GeneralOpenAIRequest{ - Model: "claude-opus-4-6", - Messages: []dto.Message{ - { - Role: "user", - Content: "what time is it", - }, - }, + if !ok { + t.Fatal("expected true") } - assistantMessage := dto.Message{ - Role: "assistant", - Content: "", - } - assistantMessage.SetToolCalls([]dto.ToolCallRequest{ - { - ID: "call_1", - Type: "function", - Function: dto.FunctionRequest{ - Name: "get_current_time", - Arguments: "{}", - }, - }, - }) - request.Messages = append(request.Messages, assistantMessage) - - claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request) - require.NoError(t, err) - require.Len(t, claudeRequest.Messages, 2) - - assistantClaudeMessage := claudeRequest.Messages[1] - assert.Equal(t, "assistant", assistantClaudeMessage.Role) - - contentBlocks, ok := assistantClaudeMessage.Content.([]dto.ClaudeMediaMessage) - require.True(t, ok) - require.Len(t, contentBlocks, 1) - - assert.Equal(t, "tool_use", contentBlocks[0].Type) - assert.Equal(t, "call_1", contentBlocks[0].Id) - assert.Equal(t, "get_current_time", contentBlocks[0].Name) - if assert.NotNil(t, contentBlocks[0].Input) { - _, isMap := contentBlocks[0].Input.(map[string]any) - assert.True(t, isMap) - } - if contentBlocks[0].Text != nil { - assert.NotEqual(t, "", *contentBlocks[0].Text) + if claudeInfo.ResponseText.String() != "hello" { + t.Errorf("ResponseText = %q, want %q", claudeInfo.ResponseText.String(), "hello") } } - -func TestRequestOpenAI2ClaudeMessage_AssistantToolCallWithMalformedArguments(t *testing.T) { - request := dto.GeneralOpenAIRequest{ - Model: "claude-opus-4-6", - Messages: []dto.Message{ - { - Role: "user", - Content: "what time is it", - }, - }, - } - assistantMessage := dto.Message{ - Role: "assistant", - Content: "", - } - assistantMessage.SetToolCalls([]dto.ToolCallRequest{ - { - ID: "call_bad_args", - Type: "function", - Function: dto.FunctionRequest{ - Name: "get_current_timestamp", - Arguments: "{", - }, - }, - }) - request.Messages = append(request.Messages, assistantMessage) - - claudeRequest, err := RequestOpenAI2ClaudeMessage(nil, request) - require.NoError(t, err) - require.Len(t, claudeRequest.Messages, 2) - - assistantClaudeMessage := claudeRequest.Messages[1] - contentBlocks, ok := assistantClaudeMessage.Content.([]dto.ClaudeMediaMessage) - require.True(t, ok) - require.Len(t, contentBlocks, 1) - - assert.Equal(t, "tool_use", contentBlocks[0].Type) - assert.Equal(t, "call_bad_args", contentBlocks[0].Id) - assert.Equal(t, "get_current_timestamp", contentBlocks[0].Name) - - inputObj, ok := contentBlocks[0].Input.(map[string]any) - require.True(t, ok) - assert.Empty(t, inputObj) -} - -func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) { - empty := "" - resp := &dto.ClaudeResponse{ - Type: "content_block_delta", - Index: func() *int { v := 1; return &v }(), - Delta: &dto.ClaudeMediaMessage{ - Type: "input_json_delta", - PartialJson: &empty, - }, - } - - chunk := StreamResponseClaude2OpenAI(resp) - require.NotNil(t, chunk) - require.Len(t, chunk.Choices, 1) - require.NotNil(t, chunk.Choices[0].Delta.ToolCalls) - require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1) - assert.Equal(t, "{}", chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) -} - -func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.T) { - partial := `{"timezone":"Asia/Shanghai"}` - resp := &dto.ClaudeResponse{ - Type: "content_block_delta", - Index: func() *int { v := 1; return &v }(), - Delta: &dto.ClaudeMediaMessage{ - Type: "input_json_delta", - PartialJson: &partial, - }, - } - - chunk := StreamResponseClaude2OpenAI(resp) - require.NotNil(t, chunk) - require.Len(t, chunk.Choices, 1) - require.NotNil(t, chunk.Choices[0].Delta.ToolCalls) - require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1) - assert.Equal(t, partial, chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) -}