diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index d9e522514..df59331ad 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -404,12 +404,15 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe return &claudeRequest, nil } -func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse { +func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ChatCompletionsStreamResponse { var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) tools := make([]dto.ToolCallResponse, 0) + if claudeInfo != nil && claudeInfo.ToolCallStreamStates == nil { + claudeInfo.ToolCallStreamStates = make(map[int]*ToolCallStreamState) + } fcIdx := 0 if claudeResponse.Index != nil { fcIdx = *claudeResponse.Index - 1 @@ -433,6 +436,13 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text) } if claudeResponse.ContentBlock.Type == "tool_use" { + if claudeInfo != nil { + claudeInfo.ToolCallStreamStates[fcIdx] = &ToolCallStreamState{ + ID: claudeResponse.ContentBlock.Id, + Name: claudeResponse.ContentBlock.Name, + } + return nil + } tools = append(tools, dto.ToolCallResponse{ Index: common.GetPointer(fcIdx), ID: claudeResponse.ContentBlock.Id, @@ -451,19 +461,28 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo choice.Delta.Content = claudeResponse.Delta.Text switch claudeResponse.Delta.Type { case "input_json_delta": - arguments := "{}" - if claudeResponse.Delta.PartialJson != nil { - if partial := strings.TrimSpace(*claudeResponse.Delta.PartialJson); partial != "" { - arguments = partial - } + if claudeResponse.Delta.PartialJson == nil { + return nil } - tools = append(tools, dto.ToolCallResponse{ + arguments := *claudeResponse.Delta.PartialJson + if strings.TrimSpace(arguments) == "" { + return nil + } + toolCall := dto.ToolCallResponse{ Type: "function", Index: common.GetPointer(fcIdx), Function: dto.FunctionResponse{ Arguments: arguments, }, - }) + } + if claudeInfo != nil { + if state, ok := claudeInfo.ToolCallStreamStates[fcIdx]; ok { + state.Emitted = true + toolCall.ID = state.ID + toolCall.Function.Name = state.Name + } + } + tools = append(tools, toolCall) case "signature_delta": // 加密的不处理 signatureContent := "\n" @@ -472,6 +491,27 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking } } + } else if claudeResponse.Type == "content_block_stop" { + if claudeInfo == nil { + return nil + } + state, ok := claudeInfo.ToolCallStreamStates[fcIdx] + if !ok { + return nil + } + delete(claudeInfo.ToolCallStreamStates, fcIdx) + if state.Emitted { + return nil + } + tools = append(tools, dto.ToolCallResponse{ + ID: state.ID, + Type: "function", + Index: common.GetPointer(fcIdx), + Function: dto.FunctionResponse{ + Name: state.Name, + Arguments: "{}", + }, + }) } else if claudeResponse.Type == "message_delta" { if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) @@ -556,12 +596,19 @@ func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextRe } type ClaudeResponseInfo struct { - ResponseId string - Created int64 - Model string - ResponseText strings.Builder - Usage *dto.Usage - Done bool + ResponseId string + Created int64 + Model string + ResponseText strings.Builder + Usage *dto.Usage + Done bool + ToolCallStreamStates map[int]*ToolCallStreamState +} + +type ToolCallStreamState struct { + ID string + Name string + Emitted bool } func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage { @@ -694,7 +741,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d // 判断是否完整 claudeInfo.Done = true - } else if claudeResponse.Type == "content_block_start" { + } else if claudeResponse.Type == "content_block_start" || claudeResponse.Type == "content_block_stop" { } else { return false } @@ -739,7 +786,10 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } helper.ClaudeChunkData(c, claudeResponse, data) } else if info.RelayFormat == types.RelayFormatOpenAI { - response := StreamResponseClaude2OpenAI(&claudeResponse) + response := StreamResponseClaude2OpenAI(&claudeResponse, claudeInfo) + if response == nil { + return nil + } if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) { return nil diff --git a/relay/channel/claude/relay_claude_test.go b/relay/channel/claude/relay_claude_test.go index 986788cf9..c1dffcb97 100644 --- a/relay/channel/claude/relay_claude_test.go +++ b/relay/channel/claude/relay_claude_test.go @@ -216,7 +216,7 @@ func TestRequestOpenAI2ClaudeMessage_AssistantToolCallWithMalformedArguments(t * assert.Empty(t, inputObj) } -func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) { +func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaIgnored(t *testing.T) { empty := "" resp := &dto.ClaudeResponse{ Type: "content_block_delta", @@ -227,12 +227,8 @@ func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) { }, } - chunk := StreamResponseClaude2OpenAI(resp) - require.NotNil(t, chunk) - require.Len(t, chunk.Choices, 1) - require.NotNil(t, chunk.Choices[0].Delta.ToolCalls) - require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1) - assert.Equal(t, "{}", chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) + chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{}) + require.Nil(t, chunk) } func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.T) { @@ -246,10 +242,71 @@ func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing. }, } - chunk := StreamResponseClaude2OpenAI(resp) + chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{}) require.NotNil(t, chunk) require.Len(t, chunk.Choices, 1) require.NotNil(t, chunk.Choices[0].Delta.ToolCalls) require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1) assert.Equal(t, partial, chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) } + +func TestStreamResponseClaude2OpenAI_NoArgToolEmitsObjectAtStop(t *testing.T) { + claudeInfo := &ClaudeResponseInfo{} + start := &dto.ClaudeResponse{ + Type: "content_block_start", + Index: func() *int { v := 1; return &v }(), + ContentBlock: &dto.ClaudeMediaMessage{ + Type: "tool_use", + Id: "toolu_1", + Name: "get_current_time", + }, + } + stop := &dto.ClaudeResponse{ + Type: "content_block_stop", + Index: func() *int { v := 1; return &v }(), + } + + startChunk := StreamResponseClaude2OpenAI(start, claudeInfo) + require.Nil(t, startChunk) + + stopChunk := StreamResponseClaude2OpenAI(stop, claudeInfo) + require.NotNil(t, stopChunk) + require.Len(t, stopChunk.Choices, 1) + require.Len(t, stopChunk.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, "toolu_1", stopChunk.Choices[0].Delta.ToolCalls[0].ID) + assert.Equal(t, "get_current_time", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Name) + assert.Equal(t, "{}", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) +} + +func TestStreamResponseClaude2OpenAI_ArgToolKeepsIDNameOnDelta(t *testing.T) { + claudeInfo := &ClaudeResponseInfo{} + start := &dto.ClaudeResponse{ + Type: "content_block_start", + Index: func() *int { v := 1; return &v }(), + ContentBlock: &dto.ClaudeMediaMessage{ + Type: "tool_use", + Id: "toolu_2", + Name: "search_notes", + }, + } + partial := `{"query":"today"}` + delta := &dto.ClaudeResponse{ + Type: "content_block_delta", + Index: func() *int { v := 1; return &v }(), + Delta: &dto.ClaudeMediaMessage{ + Type: "input_json_delta", + PartialJson: &partial, + }, + } + + startChunk := StreamResponseClaude2OpenAI(start, claudeInfo) + require.Nil(t, startChunk) + + deltaChunk := StreamResponseClaude2OpenAI(delta, claudeInfo) + require.NotNil(t, deltaChunk) + require.Len(t, deltaChunk.Choices, 1) + require.Len(t, deltaChunk.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, "toolu_2", deltaChunk.Choices[0].Delta.ToolCalls[0].ID) + assert.Equal(t, "search_notes", deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Name) + assert.Equal(t, partial, deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) +}