mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 00:48:53 +00:00
add test for fix #935
This commit is contained in:
@@ -2164,6 +2164,98 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini 原生请求中的 thoughtSignature 可能来自旧上下文/旧账号,触发上游严格校验后返回
|
||||
// "Corrupted thought signature."。检测到此类 400 时,将 thoughtSignature 清理为 dummy 值后重试一次。
|
||||
signatureCheckBody := respBody
|
||||
if unwrapped, unwrapErr := s.unwrapV1InternalResponse(respBody); unwrapErr == nil && len(unwrapped) > 0 {
|
||||
signatureCheckBody = unwrapped
|
||||
}
|
||||
if resp.StatusCode == http.StatusBadRequest &&
|
||||
s.settingService != nil &&
|
||||
s.settingService.IsSignatureRectifierEnabled(ctx) &&
|
||||
isSignatureRelatedError(signatureCheckBody) &&
|
||||
bytes.Contains(injectedBody, []byte(`"thoughtSignature"`)) {
|
||||
upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(signatureCheckBody)))
|
||||
upstreamDetail := s.getUpstreamErrorDetail(signatureCheckBody)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "signature_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: detected signature-related 400, retrying with cleaned thought signatures", account.ID)
|
||||
|
||||
cleanedInjectedBody := CleanGeminiNativeThoughtSignatures(injectedBody)
|
||||
retryWrappedBody, wrapErr := s.wrapV1InternalRequest(projectID, mappedModel, cleanedInjectedBody)
|
||||
if wrapErr == nil {
|
||||
retryResult, retryErr := s.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: prefix,
|
||||
account: account,
|
||||
proxyURL: proxyURL,
|
||||
accessToken: accessToken,
|
||||
action: upstreamAction,
|
||||
body: retryWrappedBody,
|
||||
c: c,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
accountRepo: s.accountRepo,
|
||||
handleError: s.handleUpstreamError,
|
||||
requestedModel: originalModel,
|
||||
isStickySession: isStickySession,
|
||||
groupID: 0,
|
||||
sessionHash: "",
|
||||
})
|
||||
if retryErr == nil {
|
||||
retryResp := retryResult.resp
|
||||
if retryResp.StatusCode < 400 {
|
||||
resp = retryResp
|
||||
} else {
|
||||
retryRespBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
retryOpsBody := retryRespBody
|
||||
if retryUnwrapped, unwrapErr := s.unwrapV1InternalResponse(retryRespBody); unwrapErr == nil && len(retryUnwrapped) > 0 {
|
||||
retryOpsBody = retryUnwrapped
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: retryResp.StatusCode,
|
||||
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
|
||||
Kind: "signature_retry",
|
||||
Message: sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(retryOpsBody))),
|
||||
Detail: s.getUpstreamErrorDetail(retryOpsBody),
|
||||
})
|
||||
respBody = retryRespBody
|
||||
resp = &http.Response{
|
||||
StatusCode: retryResp.StatusCode,
|
||||
Header: retryResp.Header.Clone(),
|
||||
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
|
||||
}
|
||||
contentType = resp.Header.Get("Content-Type")
|
||||
}
|
||||
} else {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Kind: "signature_retry_request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
|
||||
})
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||
}
|
||||
} else {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "Antigravity Gemini account %d: signature retry wrap failed: %v", account.ID, wrapErr)
|
||||
}
|
||||
}
|
||||
|
||||
// fallback 成功:继续按正常响应处理
|
||||
if resp.StatusCode < 400 {
|
||||
goto handleSuccess
|
||||
|
||||
@@ -134,6 +134,43 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
type queuedHTTPUpstreamStub struct {
|
||||
responses []*http.Response
|
||||
errors []error
|
||||
requestBodies [][]byte
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (s *queuedHTTPUpstreamStub) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
if req != nil && req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
s.requestBodies = append(s.requestBodies, body)
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
} else {
|
||||
s.requestBodies = append(s.requestBodies, nil)
|
||||
}
|
||||
|
||||
idx := s.callCount
|
||||
s.callCount++
|
||||
|
||||
var resp *http.Response
|
||||
if idx < len(s.responses) {
|
||||
resp = s.responses[idx]
|
||||
}
|
||||
var err error
|
||||
if idx < len(s.errors) {
|
||||
err = s.errors[idx]
|
||||
}
|
||||
if resp == nil && err == nil {
|
||||
return nil, errors.New("unexpected upstream call")
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (s *queuedHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, concurrency int, _ bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, concurrency)
|
||||
}
|
||||
|
||||
type antigravitySettingRepoStub struct{}
|
||||
|
||||
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
@@ -556,6 +593,92 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"text": "thinking", "thought": true, "thoughtSignature": "sig_bad_1"}}},
|
||||
{"role": "model", "parts": []map[string]any{{"functionCall": map[string]any{"name": "toolA", "args": map[string]any{"x": 1}}, "thoughtSignature": "sig_bad_2"}}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/antigravity/v1beta/models/gemini-3.1-pro-preview:streamGenerateContent", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
firstRespBody := []byte(`{"response":{"error":{"code":400,"message":"Corrupted thought signature.","status":"INVALID_ARGUMENT"}}}`)
|
||||
secondRespBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-Request-Id": []string{"req-sig-1"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(firstRespBody)),
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req-sig-2"},
|
||||
},
|
||||
Body: io.NopCloser(bytes.NewReader(secondRespBody)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
const originalModel = "gemini-3.1-pro-preview"
|
||||
const mappedModel = "gemini-3.1-pro-high"
|
||||
account := &Account{
|
||||
ID: 7,
|
||||
Name: "acc-gemini-signature",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
"model_mapping": map[string]any{
|
||||
originalModel: mappedModel,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, mappedModel, result.Model)
|
||||
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
|
||||
|
||||
firstReq := string(upstream.requestBodies[0])
|
||||
secondReq := string(upstream.requestBodies[1])
|
||||
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_1"`)
|
||||
require.Contains(t, firstReq, `"thoughtSignature":"sig_bad_2"`)
|
||||
require.Contains(t, secondReq, `"thoughtSignature":"skip_thought_signature_validator"`)
|
||||
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_1"`)
|
||||
require.NotContains(t, secondReq, `"thoughtSignature":"sig_bad_2"`)
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, events)
|
||||
require.Equal(t, "signature_error", events[0].Kind)
|
||||
}
|
||||
|
||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCleanGeminiNativeThoughtSignatures_ReplacesNestedThoughtSignatures(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "hello"}]
|
||||
},
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"text": "thinking", "thought": true, "thoughtSignature": "sig_1"},
|
||||
{"functionCall": {"name": "toolA", "args": {"k": "v"}}, "thoughtSignature": "sig_2"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"cachedContent": {
|
||||
"parts": [{"text": "cached", "thoughtSignature": "sig_3"}]
|
||||
},
|
||||
"signature": "keep_me"
|
||||
}`)
|
||||
|
||||
cleaned := CleanGeminiNativeThoughtSignatures(input)
|
||||
|
||||
var got map[string]any
|
||||
require.NoError(t, json.Unmarshal(cleaned, &got))
|
||||
|
||||
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_1"`)
|
||||
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_2"`)
|
||||
require.NotContains(t, string(cleaned), `"thoughtSignature":"sig_3"`)
|
||||
require.Contains(t, string(cleaned), `"thoughtSignature":"`+antigravity.DummyThoughtSignature+`"`)
|
||||
require.Contains(t, string(cleaned), `"signature":"keep_me"`)
|
||||
}
|
||||
|
||||
func TestCleanGeminiNativeThoughtSignatures_InvalidJSONReturnsOriginal(t *testing.T) {
|
||||
input := []byte(`{"contents":[invalid-json]}`)
|
||||
|
||||
cleaned := CleanGeminiNativeThoughtSignatures(input)
|
||||
|
||||
require.Equal(t, input, cleaned)
|
||||
}
|
||||
|
||||
func TestReplaceThoughtSignaturesRecursive_OnlyReplacesTargetField(t *testing.T) {
|
||||
input := map[string]any{
|
||||
"thoughtSignature": "sig_root",
|
||||
"signature": "keep_signature",
|
||||
"nested": []any{
|
||||
map[string]any{
|
||||
"thoughtSignature": "sig_nested",
|
||||
"signature": "keep_nested_signature",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got, ok := replaceThoughtSignaturesRecursive(input).(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, antigravity.DummyThoughtSignature, got["thoughtSignature"])
|
||||
require.Equal(t, "keep_signature", got["signature"])
|
||||
|
||||
nested, ok := got["nested"].([]any)
|
||||
require.True(t, ok)
|
||||
nestedMap, ok := nested[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, antigravity.DummyThoughtSignature, nestedMap["thoughtSignature"])
|
||||
require.Equal(t, "keep_nested_signature", nestedMap["signature"])
|
||||
}
|
||||
Reference in New Issue
Block a user