diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index 30e2c6b9..cd30a63f 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -62,6 +62,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
 	stream, _ := converted["stream"].(bool)
 	model, _ := converted["model"].(string)
 
+	originalWriter := c.Writer
 	writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model)
 	c.Writer = writer
 	c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody))
@@ -69,6 +70,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
 
 	h.Responses(c)
 	writer.Finalize()
+	c.Writer = originalWriter
 }
 
 type chatCompletionsResponseWriter struct {
@@ -167,6 +169,20 @@ func (w *chatCompletionsResponseWriter) SetPassthrough() {
 	w.passthrough = true
 }
 
+func (w *chatCompletionsResponseWriter) Status() int {
+	if w.ResponseWriter == nil {
+		return 0
+	}
+	return w.ResponseWriter.Status()
+}
+
+func (w *chatCompletionsResponseWriter) Written() bool {
+	if w.ResponseWriter == nil {
+		return false
+	}
+	return w.ResponseWriter.Written()
+}
+
 func (w *chatCompletionsResponseWriter) flushStreamBuffer() {
 	for {
 		buf := w.streamBuf.Bytes()
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index 382e78dd..ea40f2f1 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -1,6 +1,8 @@
 package routes
 
 import (
+	"net/http"
+
 	"github.com/Wei-Shaw/sub2api/internal/config"
 	"github.com/Wei-Shaw/sub2api/internal/handler"
 	"github.com/Wei-Shaw/sub2api/internal/server/middleware"
diff --git a/backend/internal/service/openai_chat_completions_forward.go b/backend/internal/service/openai_chat_completions_forward.go
index 703f3af1..0eefdb35 100644
--- a/backend/internal/service/openai_chat_completions_forward.go
+++ b/backend/internal/service/openai_chat_completions_forward.go
@@ -209,8 +209,8 @@ func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context,
 }
 
 func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) {
-	if s.cfg != nil {
-		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+	if s.responseHeaderFilter != nil {
+		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
 	}
 
 	c.Header("Content-Type", "text/event-stream")
@@ -409,7 +409,9 @@ func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *h
 	}
 
 	body = s.correctToolCallsInResponseBody(body)
-	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+	if s.responseHeaderFilter != nil {
+		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+	}
 
 	contentType := "application/json"
 	if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {