diff --git a/dto/openai_image.go b/dto/openai_image.go index bf35b0b12..130d1dde8 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -27,8 +27,11 @@ type ImageRequest struct { OutputCompression json.RawMessage `json:"output_compression,omitempty"` PartialImages json.RawMessage `json:"partial_images,omitempty"` // Stream bool `json:"stream,omitempty"` - Watermark *bool `json:"watermark,omitempty"` - Image json.RawMessage `json:"image,omitempty"` + Watermark *bool `json:"watermark,omitempty"` + // zhipu 4v + WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"` + UserId json.RawMessage `json:"user_id,omitempty"` + Image json.RawMessage `json:"image,omitempty"` // 用匿名参数接收额外参数 Extra map[string]json.RawMessage `json:"-"` } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 4fd6956eb..b11bea105 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -36,8 +36,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -63,6 +62,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/embeddings", specialPlan.OpenAIBaseURL), nil } return fmt.Sprintf("%s/api/paas/v4/embeddings", baseURL), nil + case relayconstant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/api/paas/v4/images/generations", baseURL), nil default: if hasSpecialPlan && specialPlan.OpenAIBaseURL != "" { return fmt.Sprintf("%s/chat/completions", specialPlan.OpenAIBaseURL), nil @@ -114,6 +115,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) } default: + if info.RelayMode == relayconstant.RelayModeImagesGenerations { + return zhipu4vImageHandler(c, resp, info) + } adaptor := openai.Adaptor{} return adaptor.DoResponse(c, resp, info) } diff --git a/relay/channel/zhipu_4v/image.go b/relay/channel/zhipu_4v/image.go new file mode 100644 index 000000000..b1fd2c8e3 --- /dev/null +++ b/relay/channel/zhipu_4v/image.go @@ -0,0 +1,127 @@ +package zhipu_4v + +import ( + "io" + "net/http" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/service" + "github.com/QuantumNous/new-api/types" + + "github.com/gin-gonic/gin" +) + +type zhipuImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + WatermarkEnabled *bool `json:"watermark_enabled,omitempty"` + UserID string `json:"user_id,omitempty"` +} + +type zhipuImageResponse struct { + Created *int64 `json:"created,omitempty"` + Data []zhipuImageData `json:"data,omitempty"` + ContentFilter any `json:"content_filter,omitempty"` + Usage *dto.Usage `json:"usage,omitempty"` + Error *zhipuImageError `json:"error,omitempty"` + RequestID string `json:"request_id,omitempty"` + ExtendParam map[string]string `json:"extendParam,omitempty"` +} + +type zhipuImageError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type zhipuImageData struct { + Url string `json:"url,omitempty"` + ImageUrl string `json:"image_url,omitempty"` + B64Json string `json:"b64_json,omitempty"` + B64Image string `json:"b64_image,omitempty"` +} + +type openAIImagePayload struct { + Created int64 `json:"created"` + Data []openAIImageData `json:"data"` +} + +type openAIImageData struct { + B64Json string `json:"b64_json"` +} + +func zhipu4vImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) + } + service.CloseResponseBodyGracefully(resp) + + var zhipuResp zhipuImageResponse + if err := common.Unmarshal(responseBody, &zhipuResp); err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + + if zhipuResp.Error != nil && zhipuResp.Error.Message != "" { + return nil, types.WithOpenAIError(types.OpenAIError{ + Message: zhipuResp.Error.Message, + Type: "zhipu_image_error", + Code: zhipuResp.Error.Code, + }, resp.StatusCode) + } + + payload := openAIImagePayload{} + if zhipuResp.Created != nil && *zhipuResp.Created != 0 { + payload.Created = *zhipuResp.Created + } else { + payload.Created = info.StartTime.Unix() + } + for _, data := range zhipuResp.Data { + url := data.Url + if url == "" { + url = data.ImageUrl + } + if url == "" { + logger.LogWarn(c, "zhipu_image_missing_url") + continue + } + + var b64 string + switch { + case data.B64Json != "": + b64 = data.B64Json + case data.B64Image != "": + b64 = data.B64Image + default: + _, downloaded, err := service.GetImageFromUrl(url) + if err != nil { + logger.LogError(c, "zhipu_image_get_b64_failed: "+err.Error()) + continue + } + b64 = downloaded + } + + if b64 == "" { + logger.LogWarn(c, "zhipu_image_empty_b64") + continue + } + + imageData := openAIImageData{ + B64Json: b64, + } + payload.Data = append(payload.Data, imageData) + } + + jsonResp, err := common.Marshal(payload) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeBadResponseBody) + } + + service.IOCopyBytesGracefully(c, resp, jsonResp) + + return &dto.Usage{}, nil +}