diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index df59331ad..d9e522514 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -404,15 +404,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe return &claudeRequest, nil } -func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ChatCompletionsStreamResponse { +func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse { var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) tools := make([]dto.ToolCallResponse, 0) - if claudeInfo != nil && claudeInfo.ToolCallStreamStates == nil { - claudeInfo.ToolCallStreamStates = make(map[int]*ToolCallStreamState) - } fcIdx := 0 if claudeResponse.Index != nil { fcIdx = *claudeResponse.Index - 1 @@ -436,13 +433,6 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text) } if claudeResponse.ContentBlock.Type == "tool_use" { - if claudeInfo != nil { - claudeInfo.ToolCallStreamStates[fcIdx] = &ToolCallStreamState{ - ID: claudeResponse.ContentBlock.Id, - Name: claudeResponse.ContentBlock.Name, - } - return nil - } tools = append(tools, dto.ToolCallResponse{ Index: common.GetPointer(fcIdx), ID: claudeResponse.ContentBlock.Id, @@ -461,28 +451,19 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo choice.Delta.Content = claudeResponse.Delta.Text switch claudeResponse.Delta.Type { case "input_json_delta": - if claudeResponse.Delta.PartialJson == nil { - return nil + arguments := "{}" + if claudeResponse.Delta.PartialJson != nil { + if partial := strings.TrimSpace(*claudeResponse.Delta.PartialJson); partial != "" { + arguments = partial + } } - arguments := *claudeResponse.Delta.PartialJson - if strings.TrimSpace(arguments) == "" { - return nil - } - toolCall := dto.ToolCallResponse{ + tools = append(tools, dto.ToolCallResponse{ Type: "function", Index: common.GetPointer(fcIdx), Function: dto.FunctionResponse{ Arguments: arguments, }, - } - if claudeInfo != nil { - if state, ok := claudeInfo.ToolCallStreamStates[fcIdx]; ok { - state.Emitted = true - toolCall.ID = state.ID - toolCall.Function.Name = state.Name - } - } - tools = append(tools, toolCall) + }) case "signature_delta": // 加密的不处理 signatureContent := "\n" @@ -491,27 +472,6 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse, claudeInfo choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking } } - } else if claudeResponse.Type == "content_block_stop" { - if claudeInfo == nil { - return nil - } - state, ok := claudeInfo.ToolCallStreamStates[fcIdx] - if !ok { - return nil - } - delete(claudeInfo.ToolCallStreamStates, fcIdx) - if state.Emitted { - return nil - } - tools = append(tools, dto.ToolCallResponse{ - ID: state.ID, - Type: "function", - Index: common.GetPointer(fcIdx), - Function: dto.FunctionResponse{ - Name: state.Name, - Arguments: "{}", - }, - }) } else if claudeResponse.Type == "message_delta" { if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) @@ -596,19 +556,12 @@ func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextRe } type ClaudeResponseInfo struct { - ResponseId string - Created int64 - Model string - ResponseText strings.Builder - Usage *dto.Usage - Done bool - ToolCallStreamStates map[int]*ToolCallStreamState -} - -type ToolCallStreamState struct { - ID string - Name string - Emitted bool + ResponseId string + Created int64 + Model string + ResponseText strings.Builder + Usage *dto.Usage + Done bool } func buildMessageDeltaPatchUsage(claudeResponse *dto.ClaudeResponse, claudeInfo *ClaudeResponseInfo) *dto.ClaudeUsage { @@ -741,7 +694,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d // 判断是否完整 claudeInfo.Done = true - } else if claudeResponse.Type == "content_block_start" || claudeResponse.Type == "content_block_stop" { + } else if claudeResponse.Type == "content_block_start" { } else { return false } @@ -786,10 +739,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } helper.ClaudeChunkData(c, claudeResponse, data) } else if info.RelayFormat == types.RelayFormatOpenAI { - response := StreamResponseClaude2OpenAI(&claudeResponse, claudeInfo) - if response == nil { - return nil - } + response := StreamResponseClaude2OpenAI(&claudeResponse) if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) { return nil diff --git a/relay/channel/claude/relay_claude_test.go b/relay/channel/claude/relay_claude_test.go index c1dffcb97..986788cf9 100644 --- a/relay/channel/claude/relay_claude_test.go +++ b/relay/channel/claude/relay_claude_test.go @@ -216,7 +216,7 @@ func TestRequestOpenAI2ClaudeMessage_AssistantToolCallWithMalformedArguments(t * assert.Empty(t, inputObj) } -func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaIgnored(t *testing.T) { +func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaFallback(t *testing.T) { empty := "" resp := &dto.ClaudeResponse{ Type: "content_block_delta", @@ -227,8 +227,12 @@ func TestStreamResponseClaude2OpenAI_EmptyInputJSONDeltaIgnored(t *testing.T) { }, } - chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{}) - require.Nil(t, chunk) + chunk := StreamResponseClaude2OpenAI(resp) + require.NotNil(t, chunk) + require.Len(t, chunk.Choices, 1) + require.NotNil(t, chunk.Choices[0].Delta.ToolCalls) + require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1) + assert.Equal(t, "{}", chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) } func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing.T) { @@ -242,71 +246,10 @@ func TestStreamResponseClaude2OpenAI_NonEmptyInputJSONDeltaPreserved(t *testing. }, } - chunk := StreamResponseClaude2OpenAI(resp, &ClaudeResponseInfo{}) + chunk := StreamResponseClaude2OpenAI(resp) require.NotNil(t, chunk) require.Len(t, chunk.Choices, 1) require.NotNil(t, chunk.Choices[0].Delta.ToolCalls) require.Len(t, chunk.Choices[0].Delta.ToolCalls, 1) assert.Equal(t, partial, chunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) } - -func TestStreamResponseClaude2OpenAI_NoArgToolEmitsObjectAtStop(t *testing.T) { - claudeInfo := &ClaudeResponseInfo{} - start := &dto.ClaudeResponse{ - Type: "content_block_start", - Index: func() *int { v := 1; return &v }(), - ContentBlock: &dto.ClaudeMediaMessage{ - Type: "tool_use", - Id: "toolu_1", - Name: "get_current_time", - }, - } - stop := &dto.ClaudeResponse{ - Type: "content_block_stop", - Index: func() *int { v := 1; return &v }(), - } - - startChunk := StreamResponseClaude2OpenAI(start, claudeInfo) - require.Nil(t, startChunk) - - stopChunk := StreamResponseClaude2OpenAI(stop, claudeInfo) - require.NotNil(t, stopChunk) - require.Len(t, stopChunk.Choices, 1) - require.Len(t, stopChunk.Choices[0].Delta.ToolCalls, 1) - assert.Equal(t, "toolu_1", stopChunk.Choices[0].Delta.ToolCalls[0].ID) - assert.Equal(t, "get_current_time", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Name) - assert.Equal(t, "{}", stopChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) -} - -func TestStreamResponseClaude2OpenAI_ArgToolKeepsIDNameOnDelta(t *testing.T) { - claudeInfo := &ClaudeResponseInfo{} - start := &dto.ClaudeResponse{ - Type: "content_block_start", - Index: func() *int { v := 1; return &v }(), - ContentBlock: &dto.ClaudeMediaMessage{ - Type: "tool_use", - Id: "toolu_2", - Name: "search_notes", - }, - } - partial := `{"query":"today"}` - delta := &dto.ClaudeResponse{ - Type: "content_block_delta", - Index: func() *int { v := 1; return &v }(), - Delta: &dto.ClaudeMediaMessage{ - Type: "input_json_delta", - PartialJson: &partial, - }, - } - - startChunk := StreamResponseClaude2OpenAI(start, claudeInfo) - require.Nil(t, startChunk) - - deltaChunk := StreamResponseClaude2OpenAI(delta, claudeInfo) - require.NotNil(t, deltaChunk) - require.Len(t, deltaChunk.Choices, 1) - require.Len(t, deltaChunk.Choices[0].Delta.ToolCalls, 1) - assert.Equal(t, "toolu_2", deltaChunk.Choices[0].Delta.ToolCalls[0].ID) - assert.Equal(t, "search_notes", deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Name) - assert.Equal(t, partial, deltaChunk.Choices[0].Delta.ToolCalls[0].Function.Arguments) -}