diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index e77080e63..88fb88083 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -73,9 +73,15 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp switch streamResponse.Type { case "response.completed": if streamResponse.Response.Usage != nil { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + } + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } if streamResponse.Response.Usage.InputTokensDetails != nil { usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens } @@ -110,9 +116,9 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp if usage.PromptTokens == 0 && usage.CompletionTokens != 0 { usage.PromptTokens = info.PromptTokens - } else { - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage, nil } diff --git a/relay/helper/common.go b/relay/helper/common.go index 5b3e76743..381147ae5 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -1,7 +1,6 @@ package helper import ( - "encoding/json" "errors" "fmt" "net/http" @@ -42,7 +41,7 @@ func SetEventStreamHeaders(c *gin.Context) { } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { - jsonData, err := json.Marshal(resp) + jsonData, err := common.Marshal(resp) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) } else { @@ -104,7 +103,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error { } func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { - jsonData, err := json.Marshal(object) + jsonData, err := common.Marshal(object) if err != nil { return fmt.Errorf("error marshalling object: %w", err) }