diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index a4de16112..c9c71327b 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -1,6 +1,7 @@ package openai import ( + "bytes" "fmt" "io" "net/http" @@ -22,6 +23,19 @@ import ( "github.com/gorilla/websocket" ) +const xaiCSAMSafetyCheckType = "SAFETY_CHECK_TYPE_CSAM" + +func maybeMarkXaiCSAMRefusal(c *gin.Context, info *relaycommon.RelayInfo, responseBody []byte) bool { + if c == nil || info == nil || len(responseBody) == 0 { + return false + } + if !bytes.Contains(responseBody, []byte(xaiCSAMSafetyCheckType)) { + return false + } + common.SetContextKey(c, constant.ContextKeyAdminRejectReason, "grok_safety_check_type=csam") + return true +} + func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error { if data == "" { return nil @@ -201,6 +215,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } + isXaiCSAMRefusal := maybeMarkXaiCSAMRefusal(c, info, responseBody) if common.DebugEnabled { println("upstream response body:", string(responseBody)) } @@ -222,10 +237,16 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { + if isXaiCSAMRefusal { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError, types.ErrOptionWithSkipRetry()) + } return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { + if isXaiCSAMRefusal { + return nil, types.WithOpenAIError(*oaiError, resp.StatusCode, types.ErrOptionWithSkipRetry()) + } return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) }