diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index fad8a33c..7c4d4638 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -628,6 +628,7 @@ func (h *AccountHandler) Delete(c *gin.Context) { // TestAccountRequest represents the request body for testing an account type TestAccountRequest struct { ModelID string `json:"model_id"` + Prompt string `json:"prompt"` } type SyncFromCRSRequest struct { @@ -658,7 +659,7 @@ func (h *AccountHandler) Test(c *gin.Context) { _ = c.ShouldBindJSON(&req) // Use AccountTestService to test the account with SSE streaming - if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID); err != nil { + if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil { // Error already sent via SSE, just log return } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index b44f29fd..472551cf 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -45,16 +45,23 @@ const ( // TestEvent represents a SSE event for account testing type TestEvent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Status string `json:"status,omitempty"` - Code string `json:"code,omitempty"` - Data any `json:"data,omitempty"` - Success bool `json:"success,omitempty"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + ImageURL string `json:"image_url,omitempty"` + MimeType string `json:"mime_type,omitempty"` + Data any `json:"data,omitempty"` + Success bool `json:"success,omitempty"` + Error string `json:"error,omitempty"` } +const ( + defaultGeminiTextTestPrompt = "hi" + defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background." +) + // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository @@ -161,7 +168,7 @@ func createTestPayload(modelID string) (map[string]any, error) { // TestAccountConnection tests an account's connection by sending a test request // All account types use full Claude Code client characteristics, only auth header differs // modelID is optional - if empty, defaults to claude.DefaultTestModel -func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { +func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error { ctx := c.Request.Context() // Get account @@ -176,11 +183,11 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int } if account.IsGemini() { - return s.testGeminiAccountConnection(c, account, modelID) + return s.testGeminiAccountConnection(c, account, modelID, prompt) } if account.Platform == PlatformAntigravity { - return s.routeAntigravityTest(c, account, modelID) + return s.routeAntigravityTest(c, account, modelID, prompt) } if account.Platform == PlatformSora { @@ -435,7 +442,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account } // testGeminiAccountConnection tests a Gemini account's connection -func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error { ctx := c.Request.Context() // Determine the model to use @@ -462,7 +469,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account c.Writer.Flush() // Create test payload (Gemini format) - payload := createGeminiTestPayload() + payload := createGeminiTestPayload(testModelID, prompt) // Build request based on account type var req *http.Request @@ -1198,10 +1205,10 @@ func truncateSoraErrorBody(body []byte, max int) string { // routeAntigravityTest 路由 Antigravity 账号的测试请求。 // APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。 -func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string) error { +func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error { if account.Type == AccountTypeAPIKey { if strings.HasPrefix(modelID, "gemini-") { - return s.testGeminiAccountConnection(c, account, modelID) + return s.testGeminiAccountConnection(c, account, modelID, prompt) } return s.testClaudeAccountConnection(c, account, modelID) } @@ -1349,14 +1356,46 @@ func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessT return req, nil } -// createGeminiTestPayload creates a minimal test payload for Gemini API -func createGeminiTestPayload() []byte { +// createGeminiTestPayload creates a minimal test payload for Gemini API. +// Image models use the image-generation path so the frontend can preview the returned image. +func createGeminiTestPayload(modelID string, prompt string) []byte { + if isImageGenerationModel(modelID) { + imagePrompt := strings.TrimSpace(prompt) + if imagePrompt == "" { + imagePrompt = defaultGeminiImageTestPrompt + } + + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": imagePrompt}, + }, + }, + }, + "generationConfig": map[string]any{ + "responseModalities": []string{"TEXT", "IMAGE"}, + "imageConfig": map[string]any{ + "aspectRatio": "1:1", + }, + }, + } + bytes, _ := json.Marshal(payload) + return bytes + } + + textPrompt := strings.TrimSpace(prompt) + if textPrompt == "" { + textPrompt = defaultGeminiTextTestPrompt + } + payload := map[string]any{ "contents": []map[string]any{ { "role": "user", "parts": []map[string]any{ - {"text": "hi"}, + {"text": textPrompt}, }, }, }, @@ -1416,6 +1455,17 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) if text, ok := partMap["text"].(string); ok && text != "" { s.sendEvent(c, TestEvent{Type: "content", Text: text}) } + if inlineData, ok := partMap["inlineData"].(map[string]any); ok { + mimeType, _ := inlineData["mimeType"].(string) + data, _ := inlineData["data"].(string) + if strings.HasPrefix(strings.ToLower(mimeType), "image/") && data != "" { + s.sendEvent(c, TestEvent{ + Type: "image", + ImageURL: fmt.Sprintf("data:%s;base64,%s", mimeType, data), + MimeType: mimeType, + }) + } + } } } } @@ -1602,7 +1652,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in ginCtx, _ := gin.CreateTestContext(w) ginCtx.Request = (&http.Request{}).WithContext(ctx) - testErr := s.TestAccountConnection(ginCtx, accountID, modelID) + testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "") finishedAt := time.Now() body := w.Body.String() diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go new file mode 100644 index 00000000..5ba04c69 --- /dev/null +++ b/backend/internal/service/account_test_service_gemini_test.go @@ -0,0 +1,59 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestCreateGeminiTestPayload_ImageModel(t *testing.T) { + t.Parallel() + + payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot") + + var parsed struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"contents"` + GenerationConfig struct { + ResponseModalities []string `json:"responseModalities"` + ImageConfig struct { + AspectRatio string `json:"aspectRatio"` + } `json:"imageConfig"` + } `json:"generationConfig"` + } + + require.NoError(t, json.Unmarshal(payload, &parsed)) + require.Len(t, parsed.Contents, 1) + require.Len(t, parsed.Contents[0].Parts, 1) + require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text) + require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities) + require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio) +} + +func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + ctx, recorder := newSoraTestContext() + svc := &AccountTestService{} + + stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n") + + err := svc.processGeminiStream(ctx, stream) + require.NoError(t, err) + + body := recorder.Body.String() + require.Contains(t, body, "\"type\":\"content\"") + require.Contains(t, body, "\"text\":\"ok\"") + require.Contains(t, body, "\"type\":\"image\"") + require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"") + require.Contains(t, body, "\"mime_type\":\"image/png\"") +} diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 792a8f45..04ab032f 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -15,7 +15,7 @@
- +
{{ account.name }}
@@ -61,6 +61,17 @@ {{ t('admin.accounts.soraTestHint') }}
+
+