From 11f7b835229273e993d016c9f7190e29400d8256 Mon Sep 17 00:00:00 2001 From: Ylarod Date: Thu, 12 Mar 2026 18:42:02 +0800 Subject: [PATCH 1/2] sub2api: add bedrock support --- backend/go.mod | 4 +- backend/go.sum | 4 + backend/internal/domain/constants.go | 26 + .../internal/handler/admin/account_handler.go | 4 +- backend/internal/service/account.go | 9 + .../internal/service/account_test_service.go | 117 +++- backend/internal/service/bedrock_request.go | 607 ++++++++++++++++ .../internal/service/bedrock_request_test.go | 659 ++++++++++++++++++ backend/internal/service/bedrock_signer.go | 67 ++ .../internal/service/bedrock_signer_test.go | 35 + backend/internal/service/bedrock_stream.go | 414 +++++++++++ .../internal/service/bedrock_stream_test.go | 261 +++++++ backend/internal/service/domain_constants.go | 2 + backend/internal/service/gateway_service.go | 467 +++++++++++++ .../gateway_service_bedrock_beta_test.go | 267 +++++++ ...eway_service_bedrock_model_support_test.go | 48 ++ .../components/account/CreateAccountModal.vue | 456 +++++++++++- .../components/account/EditAccountModal.vue | 355 ++++++++++ .../components/common/PlatformTypeBadge.vue | 2 + frontend/src/composables/useModelWhitelist.ts | 10 + frontend/src/i18n/locales/en.ts | 19 + frontend/src/i18n/locales/zh.ts | 19 + frontend/src/types/index.ts | 2 +- 23 files changed, 3839 insertions(+), 15 deletions(-) create mode 100644 backend/internal/service/bedrock_request.go create mode 100644 backend/internal/service/bedrock_request_test.go create mode 100644 backend/internal/service/bedrock_signer.go create mode 100644 backend/internal/service/bedrock_signer_test.go create mode 100644 backend/internal/service/bedrock_stream.go create mode 100644 backend/internal/service/bedrock_stream_test.go create mode 100644 backend/internal/service/gateway_service_bedrock_beta_test.go create mode 100644 backend/internal/service/gateway_service_bedrock_model_support_test.go diff --git a/backend/go.mod b/backend/go.mod 
index 03637401..135cbd3e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -7,7 +7,7 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 - github.com/aws/aws-sdk-go-v2 v1.41.2 + github.com/aws/aws-sdk-go-v2 v1.41.3 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 @@ -66,7 +66,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect - github.com/aws/smithy-go v1.24.1 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/bdandy/go-errors v1.2.2 // indirect github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 993a1d54..324fe652 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -24,6 +24,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= @@ -60,6 +62,8 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8 
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 8a6621a1..9920a76b 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -31,6 +31,8 @@ const ( AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeAPIKey = "apikey" // API Key类型账号 AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) + AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) ) // Redeem type constants @@ -113,3 +115,27 @@ var DefaultAntigravityModelMapping = map[string]string{ "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", } + +// DefaultBedrockModelMapping 是 AWS Bedrock 平台的默认模型映射 +// 将 Anthropic 标准模型名映射到 Bedrock 模型 ID +// 注意:此处的 "us." 前缀仅为默认值,ResolveBedrockModelID 会根据账号配置的 +// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 
等) +var DefaultBedrockModelMapping = map[string]string{ + // Claude Opus + "claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1", + "claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1", + "claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-5-20251101": "us.anthropic.claude-opus-4-5-20251101-v1:0", + "claude-opus-4-1": "us.anthropic.claude-opus-4-1-20250805-v1:0", + "claude-opus-4-20250514": "us.anthropic.claude-opus-4-20250514-v1:0", + // Claude Sonnet + "claude-sonnet-4-6-thinking": "us.anthropic.claude-sonnet-4-6", + "claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6", + "claude-sonnet-4-5": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-5-thinking": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-5-20250929": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-sonnet-4-20250514": "us.anthropic.claude-sonnet-4-20250514-v1:0", + // Claude Haiku + "claude-haiku-4-5": "us.anthropic.claude-haiku-4-5-20251001-v1:0", + "claude-haiku-4-5-20251001": "us.anthropic.claude-haiku-4-5-20251001-v1:0", +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 57c2dad1..c7ca0ca2 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -97,7 +97,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -116,7 +116,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes 
*string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream bedrock bedrock-apikey"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 9d4f73d4..9a871c10 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -412,6 +412,7 @@ func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]stri if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } + // Bedrock 默认映射由 forwardBedrock 统一处理(需配合 region prefix 调整) return nil } if len(rawMapping) == 0 { @@ -764,6 +765,14 @@ func (a *Account) IsInterceptWarmupEnabled() bool { return false } +func (a *Account) IsBedrock() bool { + return a.Platform == PlatformAnthropic && (a.Type == AccountTypeBedrock || a.Type == AccountTypeBedrockAPIKey) +} + +func (a *Account) IsBedrockAPIKey() bool { + return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrockAPIKey +} + func (a *Account) IsOpenAI() bool { return a.Platform == PlatformOpenAI } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 472551cf..482d22b1 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -207,14 +207,14 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account testModelID = claude.DefaultTestModel } - // For API Key accounts with model mapping, map the model + // API Key 账号测试连接时也需要应用通配符模型映射。 if account.Type == "apikey" { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } + testModelID = 
account.GetMappedModel(testModelID) + } + + // Bedrock accounts use a separate test path + if account.IsBedrock() { + return s.testBedrockAccountConnection(c, ctx, account, testModelID) } // Determine authentication method and API URL @@ -312,6 +312,109 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account return s.processClaudeStream(c, resp.Body) } +// testBedrockAccountConnection tests a Bedrock (SigV4 or API Key) account using non-streaming invoke +func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx context.Context, account *Account, testModelID string) error { + region := bedrockRuntimeRegion(account) + resolvedModelID, ok := ResolveBedrockModelID(account, testModelID) + if !ok { + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported Bedrock model: %s", testModelID)) + } + testModelID = resolvedModelID + + // Set SSE headers (test UI expects SSE) + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create a minimal Bedrock-compatible payload (no stream, no cache_control) + bedrockPayload := map[string]any{ + "anthropic_version": "bedrock-2023-05-31", + "messages": []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + { + "type": "text", + "text": "hi", + }, + }, + }, + }, + "max_tokens": 256, + "temperature": 1, + } + bedrockBody, _ := json.Marshal(bedrockPayload) + + // Use non-streaming endpoint (response is standard Claude JSON) + apiURL := BuildBedrockURL(region, testModelID, false) + + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bedrockBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Header.Set("Content-Type", "application/json") + + // Sign or set 
auth based on account type + if account.IsBedrockAPIKey() { + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + req.Header.Set("Authorization", "Bearer "+apiKey) + } else { + signer, err := NewBedrockSignerFromAccount(account) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create Bedrock signer: %s", err.Error())) + } + if err := signer.SignRequest(ctx, req, bedrockBody); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to sign request: %s", err.Error())) + } + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Bedrock non-streaming response is standard Claude JSON, extract the text + var result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } + if err := json.Unmarshal(body, &result); err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error())) + } + + text := "" + if len(result.Content) > 0 { + text = result.Content[0].Text + } + if text == "" { + text = "(empty response)" + } + + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + // testOpenAIAccountConnection tests an OpenAI account's connection func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { ctx := c.Request.Context() diff --git a/backend/internal/service/bedrock_request.go 
b/backend/internal/service/bedrock_request.go new file mode 100644 index 00000000..2160c13c --- /dev/null +++ b/backend/internal/service/bedrock_request.go @@ -0,0 +1,607 @@ +package service + +import ( + "encoding/json" + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const defaultBedrockRegion = "us-east-1" + +var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."} + +// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀 +// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +func BedrockCrossRegionPrefix(region string) string { + switch { + case strings.HasPrefix(region, "us-gov"): + return "us-gov" // GovCloud 使用独立的 us-gov 前缀 + case strings.HasPrefix(region, "us-"): + return "us" + case strings.HasPrefix(region, "eu-"): + return "eu" + case region == "ap-northeast-1": + return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义) + case region == "ap-southeast-2": + return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义) + case strings.HasPrefix(region, "ap-"): + return "apac" // 其余亚太区域使用通用 apac 前缀 + case strings.HasPrefix(region, "ca-"): + return "us" // 加拿大区域使用 us 前缀的跨区域推理 + case strings.HasPrefix(region, "sa-"): + return "us" // 南美区域使用 us 前缀的跨区域推理 + default: + return "us" + } +} + +// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀 +// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1" +// 特殊值 region="global" 强制使用 global. 前缀 +func AdjustBedrockModelRegionPrefix(modelID, region string) string { + var targetPrefix string + if region == "global" { + targetPrefix = "global" + } else { + targetPrefix = BedrockCrossRegionPrefix(region) + } + + for _, p := range bedrockCrossRegionPrefixes { + if strings.HasPrefix(modelID, p) { + if p == targetPrefix+"." { + return modelID // 前缀已匹配,无需替换 + } + return targetPrefix + "." 
+ modelID[len(p):] + } + } + + // 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改 + return modelID +} + +func bedrockRuntimeRegion(account *Account) string { + if account == nil { + return defaultBedrockRegion + } + if region := account.GetCredential("aws_region"); region != "" { + return region + } + return defaultBedrockRegion +} + +func shouldForceBedrockGlobal(account *Account) bool { + return account != nil && account.GetCredential("aws_force_global") == "true" +} + +func isRegionalBedrockModelID(modelID string) bool { + for _, prefix := range bedrockCrossRegionPrefixes { + if strings.HasPrefix(modelID, prefix) { + return true + } + } + return false +} + +func isLikelyBedrockModelID(modelID string) bool { + lower := strings.ToLower(strings.TrimSpace(modelID)) + if lower == "" { + return false + } + if strings.HasPrefix(lower, "arn:") { + return true + } + for _, prefix := range []string{ + "anthropic.", + "amazon.", + "meta.", + "mistral.", + "cohere.", + "ai21.", + "deepseek.", + "stability.", + "writer.", + "nova.", + } { + if strings.HasPrefix(lower, prefix) { + return true + } + } + return isRegionalBedrockModelID(lower) +} + +func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "", false, false + } + if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists { + return mapped, true, true + } + if isRegionalBedrockModelID(modelID) { + return modelID, true, true + } + if isLikelyBedrockModelID(modelID) { + return modelID, false, true + } + return "", false, false +} + +// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID. +// It applies account model_mapping first, then default Bedrock aliases, and finally +// adjusts Anthropic cross-region prefixes to match the account region. 
+func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) { + if account == nil { + return "", false + } + + mappedModel := account.GetMappedModel(requestedModel) + modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel) + if !ok { + return "", false + } + if shouldAdjustRegion { + targetRegion := bedrockRuntimeRegion(account) + if shouldForceBedrockGlobal(account) { + targetRegion = "global" + } + modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion) + } + return modelID, true +} + +// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL +// stream=true 时使用 invoke-with-response-stream 端点 +// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐) +func BuildBedrockURL(region, modelID string, stream bool) string { + if region == "" { + region = defaultBedrockRegion + } + encodedModelID := url.PathEscape(modelID) + // url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"), + // 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A + encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A") + if stream { + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID) + } + return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID) +} + +// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API +// 1. 注入 anthropic_version +// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析) +// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config) +// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true}) +// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl) +func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) { + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens) +} + +// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens. 
+func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) { + var err error + + // 注入 anthropic_version(Bedrock 要求) + body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31") + if err != nil { + return nil, fmt.Errorf("inject anthropic_version: %w", err) + } + + // 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头) + // 1. 从客户端 anthropic-beta header 解析 + // 2. 根据请求体内容自动补齐必要的 beta token + // 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock() + if len(betaTokens) > 0 { + body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens) + if err != nil { + return nil, fmt.Errorf("inject anthropic_beta: %w", err) + } + } + + // 移除 model 字段(Bedrock 通过 URL 指定模型) + body, err = sjson.DeleteBytes(body, "model") + if err != nil { + return nil, fmt.Errorf("remove model field: %w", err) + } + + // 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段) + body, err = sjson.DeleteBytes(body, "stream") + if err != nil { + return nil, fmt.Errorf("remove stream field: %w", err) + } + + // 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message) + // 参考 litellm: _convert_output_format_to_inline_schema() + body = convertOutputFormatToInlineSchema(body) + + // 移除 output_config 字段(Bedrock Invoke 不支持) + body, err = sjson.DeleteBytes(body, "output_config") + if err != nil { + return nil, fmt.Errorf("remove output_config field: %w", err) + } + + // 移除工具定义中的 custom 字段 + // Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true}, + // Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted" + body = removeCustomFieldFromTools(body) + + // 清理 cache_control 中 Bedrock 不支持的字段 + body = sanitizeBedrockCacheControl(body, modelID) + + return body, nil +} + +// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering. 
+func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string { + betaTokens := parseAnthropicBetaHeader(betaHeader) + betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID) + return filterBedrockBetaTokens(betaTokens) +} + +// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message +// Bedrock Invoke 不支持 output_format 参数,litellm 的做法是将 schema 追加到用户消息中 +// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema() +func convertOutputFormatToInlineSchema(body []byte) []byte { + outputFormat := gjson.GetBytes(body, "output_format") + if !outputFormat.Exists() || !outputFormat.IsObject() { + return body + } + + // 先从请求体中移除 output_format + body, _ = sjson.DeleteBytes(body, "output_format") + + schema := outputFormat.Get("schema") + if !schema.Exists() { + return body + } + + // 找到最后一条 user message + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + msgArr := messages.Array() + lastUserIdx := -1 + for i := len(msgArr) - 1; i >= 0; i-- { + if msgArr[i].Get("role").String() == "user" { + lastUserIdx = i + break + } + } + if lastUserIdx < 0 { + return body + } + + // 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组 + schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw)) + if err != nil { + return body + } + + content := msgArr[lastUserIdx].Get("content") + basePath := fmt.Sprintf("messages.%d.content", lastUserIdx) + + if content.IsArray() { + // 追加一个 text block 到 content 数组末尾 + idx := len(content.Array()) + body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text") + body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON)) + } else if content.Type == gjson.String { + // content 是纯字符串,转换为数组格式 + originalText := content.String() + body, _ = sjson.SetBytes(body, basePath, []map[string]string{ + {"type": "text", "text": originalText}, + {"type": "text", 
"text": string(schemaJSON)}, + }) + } + + return body +} + +// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段 +func removeCustomFieldFromTools(body []byte) []byte { + tools := gjson.GetBytes(body, "tools") + if !tools.Exists() || !tools.IsArray() { + return body + } + var err error + for i := range tools.Array() { + body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i)) + if err != nil { + // 删除失败不影响整体流程,跳过 + continue + } + } + return body +} + +// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分 +// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式 +var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`) + +// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本 +// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h") +func isBedrockClaude45OrNewer(modelID string) bool { + lower := strings.ToLower(modelID) + matches := claudeVersionRe.FindStringSubmatch(lower) + if matches == nil { + return false + } + major, _ := strconv.Atoi(matches[1]) + minor, _ := strconv.Atoi(matches[2]) + return major > 4 || (major == 4 && minor >= 5) +} + +// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里 +// Bedrock 不支持的字段: +// - scope:Bedrock 不支持(如 "global" 跨请求缓存) +// - ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除 +func sanitizeBedrockCacheControl(body []byte, modelID string) []byte { + isClaude45 := isBedrockClaude45OrNewer(modelID) + + // 清理 system 数组中的 cache_control + systemArr := gjson.GetBytes(body, "system") + if systemArr.Exists() && systemArr.IsArray() { + for i, item := range systemArr.Array() { + if !item.IsObject() { + continue + } + cc := item.Get("cache_control") + if !cc.Exists() || !cc.IsObject() { + continue + } + body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45) + } + } + + // 清理 messages 中的 cache_control + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + for 
mi, msg := range messages.Array() { + if !msg.IsObject() { + continue + } + content := msg.Get("content") + if !content.Exists() || !content.IsArray() { + continue + } + for ci, block := range content.Array() { + if !block.IsObject() { + continue + } + cc := block.Get("cache_control") + if !cc.Exists() || !cc.IsObject() { + continue + } + body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45) + } + } + + return body +} + +// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段 +func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte { + // Bedrock 不支持 scope(如 "global") + if cc.Get("scope").Exists() { + body, _ = sjson.DeleteBytes(body, basePath+".scope") + } + + // ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除 + ttl := cc.Get("ttl") + if ttl.Exists() { + shouldRemove := true + if isClaude45 { + v := ttl.String() + if v == "5m" || v == "1h" { + shouldRemove = false + } + } + if shouldRemove { + body, _ = sjson.DeleteBytes(body, basePath+".ttl") + } + } + + return body +} + +// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表 +func parseAnthropicBetaHeader(header string) []string { + header = strings.TrimSpace(header) + if header == "" { + return nil + } + if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") { + var parsed []any + if err := json.Unmarshal([]byte(header), &parsed); err == nil { + tokens := make([]string, 0, len(parsed)) + for _, item := range parsed { + token := strings.TrimSpace(fmt.Sprint(item)) + if token != "" { + tokens = append(tokens, token) + } + } + return tokens + } + } + var tokens []string + for _, part := range strings.Split(header, ",") { + t := strings.TrimSpace(part) + if t != "" { + tokens = append(tokens, t) + } + } + return tokens +} + +// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单 +// 参考: litellm/litellm/llms/bedrock/common_utils.py 
(anthropic_beta_headers_config.json) +// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单 +var bedrockSupportedBetaTokens = map[string]bool{ + "computer-use-2025-01-24": true, + "computer-use-2025-11-24": true, + "context-1m-2025-08-07": true, + "context-management-2025-06-27": true, + "compact-2026-01-12": true, + "interleaved-thinking-2025-05-14": true, + "tool-search-tool-2025-10-19": true, + "tool-examples-2025-10-29": true, +} + +// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则 +// Anthropic 直接 API 使用通用头,Bedrock Invoke 需要特定的替代头 +var bedrockBetaTokenTransforms = map[string]string{ + "advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19", +} + +// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token +// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和 +// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock() +// +// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token, +// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。 +func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string { + seen := make(map[string]bool, len(tokens)) + for _, t := range tokens { + seen[t] = true + } + + inject := func(token string) { + if !seen[token] { + tokens = append(tokens, token) + seen[token] = true + } + } + + // 检测 thinking / interleaved thinking + // 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta + if gjson.GetBytes(body, "thinking").Exists() { + inject("interleaved-thinking-2025-05-14") + } + + // 检测 computer_use 工具 + // tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() { + toolSearchUsed := false + programmaticToolCallingUsed := false + inputExamplesUsed := false + for _, tool := range tools.Array() { + toolType := tool.Get("type").String() + if strings.HasPrefix(toolType, "computer_20") { + inject("computer-use-2025-11-24") + } + if isBedrockToolSearchType(toolType) { + 
toolSearchUsed = true + } + if hasCodeExecutionAllowedCallers(tool) { + programmaticToolCallingUsed = true + } + if hasInputExamples(tool) { + inputExamplesUsed = true + } + } + if programmaticToolCallingUsed || inputExamplesUsed { + // programmatic tool calling 和 input examples 需要 advanced-tool-use, + // 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool + inject("advanced-tool-use-2025-11-20") + } + if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) { + // 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头, + // 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐) + if !programmaticToolCallingUsed && !inputExamplesUsed { + inject("tool-search-tool-2025-10-19") + } else { + inject("advanced-tool-use-2025-11-20") + } + } + } + + return tokens +} + +func isBedrockToolSearchType(toolType string) bool { + return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119" +} + +func hasCodeExecutionAllowedCallers(tool gjson.Result) bool { + allowedCallers := tool.Get("allowed_callers") + if containsStringInJSONArray(allowedCallers, "code_execution_20250825") { + return true + } + return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825") +} + +func hasInputExamples(tool gjson.Result) bool { + if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 { + return true + } + arr := tool.Get("function.input_examples") + return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 +} + +func containsStringInJSONArray(result gjson.Result, target string) bool { + if !result.Exists() || !result.IsArray() { + return false + } + for _, item := range result.Array() { + if item.String() == target { + return true + } + } + return false +} + +// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search +// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持 +func bedrockModelSupportsToolSearch(modelID string) bool { + lower := 
strings.ToLower(modelID) + matches := claudeVersionRe.FindStringSubmatch(lower) + if matches == nil { + return false + } + // Haiku 不支持 tool search + if strings.Contains(lower, "haiku") { + return false + } + major, _ := strconv.Atoi(matches[1]) + minor, _ := strconv.Atoi(matches[2]) + return major > 4 || (major == 4 && minor >= 5) +} + +// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token +// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool) +// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等) +// 3. 自动关联 tool-examples(当 tool-search-tool 存在时) +func filterBedrockBetaTokens(tokens []string) []string { + seen := make(map[string]bool, len(tokens)) + var result []string + + for _, t := range tokens { + // 应用转换规则 + if replacement, ok := bedrockBetaTokenTransforms[t]; ok { + t = replacement + } + // 只保留白名单中的 token,且去重 + if bedrockSupportedBetaTokens[t] && !seen[t] { + result = append(result, t) + seen[t] = true + } + } + + // 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在 + if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] { + result = append(result, "tool-examples-2025-10-29") + } + + return result +} diff --git a/backend/internal/service/bedrock_request_test.go b/backend/internal/service/bedrock_request_test.go new file mode 100644 index 00000000..361cafb4 --- /dev/null +++ b/backend/internal/service/bedrock_request_test.go @@ -0,0 +1,659 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) { + input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") + require.NoError(t, err) + + // anthropic_version 应被注入 + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, 
"anthropic_version").String()) + // model 和 stream 应被移除 + assert.False(t, gjson.GetBytes(result, "model").Exists()) + assert.False(t, gjson.GetBytes(result, "stream").Exists()) + // max_tokens 应保留 + assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int()) +} + +func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) { + t.Run("schema inlined into last user message array content", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + // schema 应内联到最后一条 user message 的 content 数组末尾 + contentArr := gjson.GetBytes(result, "messages.0.content").Array() + require.Len(t, contentArr, 2) + assert.Equal(t, "text", contentArr[1].Get("type").String()) + assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`) + }) + + t.Run("schema inlined into string content", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + contentArr := gjson.GetBytes(result, "messages.0.content").Array() + require.Len(t, contentArr, 2) + assert.Equal(t, "compute this", contentArr[0].Get("text").String()) + assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`) + }) + + t.Run("no schema field just removes output_format", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}` + result, err := 
PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + }) + + t.Run("no messages just removes output_format", func(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + }) +} + +func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) { + input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "") + require.NoError(t, err) + + assert.False(t, gjson.GetBytes(result, "output_config").Exists()) +} + +func TestRemoveCustomFieldFromTools(t *testing.T) { + input := `{ + "tools": [ + {"name":"tool1","custom":{"defer_loading":true},"description":"desc1"}, + {"name":"tool2","description":"desc2"}, + {"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"} + ] + }` + result := removeCustomFieldFromTools([]byte(input)) + + tools := gjson.GetBytes(result, "tools").Array() + require.Len(t, tools, 3) + // custom 应被移除 + assert.False(t, tools[0].Get("custom").Exists()) + // name/description 应保留 + assert.Equal(t, "tool1", tools[0].Get("name").String()) + assert.Equal(t, "desc1", tools[0].Get("description").String()) + // 没有 custom 的工具不受影响 + assert.Equal(t, "tool2", tools[1].Get("name").String()) + // 第三个工具的 custom 也应被移除 + assert.False(t, tools[2].Get("custom").Exists()) + assert.Equal(t, "tool3", tools[2].Get("name").String()) +} + +func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}]}` + result := removeCustomFieldFromTools([]byte(input)) + // 无 tools 时不改变原始数据 + 
assert.JSONEq(t, input, string(result)) +} + +func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) { + input := `{ + "system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}], + "messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}] + }` + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1") + + // scope 应被移除 + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists()) + assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists()) + // type 应保留 + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String()) + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String()) +} + +func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) { + input := `{ + "system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}] + }` + // 旧模型(Claude 3.5)不支持 ttl + result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0") + + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists()) + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String()) +} + +func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) { + input := `{ + "system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}] + }` + // Claude 4.5+ 支持 "5m" 和 "1h" + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0") + + assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists()) + assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String()) +} + +func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) { + input := `{ + "system": 
[{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}] + }` + // Claude 4.5 不支持 "10m" + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0") + + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists()) +} + +func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) { + input := `{ + "messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}] + }` + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1") + + assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists()) + assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) { + input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}` + result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1") + // 无 cache_control 时不改变原始数据 + assert.JSONEq(t, input, string(result)) +} + +func TestIsBedrockClaude45OrNewer(t *testing.T) { + tests := []struct { + modelID string + expect bool + }{ + {"us.anthropic.claude-opus-4-6-v1", true}, + {"us.anthropic.claude-sonnet-4-6", true}, + {"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true}, + {"us.anthropic.claude-opus-4-5-20251101-v1:0", true}, + {"us.anthropic.claude-haiku-4-5-20251001-v1:0", true}, + {"anthropic.claude-3-5-sonnet-20241022-v2:0", false}, + {"anthropic.claude-3-opus-20240229-v1:0", false}, + {"anthropic.claude-3-haiku-20240307-v1:0", false}, + // 未来版本应自动支持 + {"us.anthropic.claude-sonnet-5-0-v1", true}, + {"us.anthropic.claude-opus-4-7-v1", true}, + // 旧版本 + {"anthropic.claude-opus-4-1-v1", false}, + {"anthropic.claude-sonnet-4-0-v1", false}, + // 非 Claude 模型 + {"amazon.nova-pro-v1", false}, + {"meta.llama3-70b", false}, + } + for _, tt := range 
tests { + t.Run(tt.modelID, func(t *testing.T) { + assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID)) + }) + } +} + +func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) { + // 模拟一个完整的 Claude Code 请求 + input := `{ + "model": "claude-opus-4-6", + "stream": true, + "max_tokens": 16384, + "output_format": {"type": "json", "schema": {"result": "string"}}, + "output_config": {"max_tokens": 100}, + "system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}], + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]} + ], + "tools": [ + {"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}}, + {"name": "read", "description": "Read file", "input_schema": {"type": "object"}} + ] + }` + + betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12" + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader) + require.NoError(t, err) + + // 基本字段 + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String()) + assert.False(t, gjson.GetBytes(result, "model").Exists()) + assert.False(t, gjson.GetBytes(result, "stream").Exists()) + assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int()) + + // anthropic_beta 应包含所有 beta tokens + betaArr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, betaArr, 3) + assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String()) + assert.Equal(t, "compact-2026-01-12", betaArr[2].String()) + + // output_format 应被移除,schema 内联到最后一条 user message + assert.False(t, gjson.GetBytes(result, "output_format").Exists()) + assert.False(t, gjson.GetBytes(result, "output_config").Exists()) + // content 数组:原始 text block + 内联 
schema block + contentArr := gjson.GetBytes(result, "messages.0.content").Array() + require.Len(t, contentArr, 2) + assert.Equal(t, "hello", contentArr[0].Get("text").String()) + assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`) + + // tools 中的 custom 应被移除 + assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists()) + assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String()) + assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String()) + + // cache_control: scope 应被移除,ttl 在 Claude 4.6 上保留合法值 + assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists()) + assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String()) + assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String()) + assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String()) +} + +func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}` + + t.Run("empty beta header", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") + require.NoError(t, err) + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) + }) + + t.Run("single beta token", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 1) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + }) + + t.Run("multiple beta tokens with spaces", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + 
require.Len(t, arr, 2) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[1].String()) + }) + + t.Run("json array beta header", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`) + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 2) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[1].String()) + }) +} + +func TestParseAnthropicBetaHeader(t *testing.T) { + assert.Nil(t, parseAnthropicBetaHeader("")) + assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a")) + assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b")) + assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b ")) + assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c")) + assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`)) +} + +func TestFilterBedrockBetaTokens(t *testing.T) { + t.Run("supported tokens pass through", func(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"} + result := filterBedrockBetaTokens(tokens) + assert.Equal(t, tokens, result) + }) + + t.Run("unsupported tokens are filtered out", func(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"} + result := filterBedrockBetaTokens(tokens) + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result) + }) + + t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) { + tokens := []string{"advanced-tool-use-2025-11-20"} + result := filterBedrockBetaTokens(tokens) + assert.Contains(t, result, "tool-search-tool-2025-10-19") + // tool-examples 自动关联 + assert.Contains(t, 
result, "tool-examples-2025-10-29") + }) + + t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) { + tokens := []string{"tool-search-tool-2025-10-19"} + result := filterBedrockBetaTokens(tokens) + assert.Contains(t, result, "tool-search-tool-2025-10-19") + assert.Contains(t, result, "tool-examples-2025-10-29") + }) + + t.Run("no duplication when tool-examples already present", func(t *testing.T) { + tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"} + result := filterBedrockBetaTokens(tokens) + count := 0 + for _, t := range result { + if t == "tool-examples-2025-10-29" { + count++ + } + } + assert.Equal(t, 1, count) + }) + + t.Run("empty input returns nil", func(t *testing.T) { + result := filterBedrockBetaTokens(nil) + assert.Nil(t, result) + }) + + t.Run("all unsupported returns nil", func(t *testing.T) { + result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"}) + assert.Nil(t, result) + }) + + t.Run("duplicate tokens are deduplicated", func(t *testing.T) { + tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"} + result := filterBedrockBetaTokens(tokens) + assert.Equal(t, []string{"context-1m-2025-08-07"}, result) + }) +} + +func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}` + + t.Run("unsupported beta tokens are filtered", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", + "interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 1) + assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + }) + + t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) { + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", + 
"advanced-tool-use-2025-11-20") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + require.Len(t, arr, 2) + assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String()) + assert.Equal(t, "tool-examples-2025-10-29", arr[1].String()) + }) +} + +func TestBedrockCrossRegionPrefix(t *testing.T) { + tests := []struct { + region string + expect string + }{ + // US regions + {"us-east-1", "us"}, + {"us-east-2", "us"}, + {"us-west-1", "us"}, + {"us-west-2", "us"}, + // GovCloud + {"us-gov-east-1", "us-gov"}, + {"us-gov-west-1", "us-gov"}, + // EU regions + {"eu-west-1", "eu"}, + {"eu-west-2", "eu"}, + {"eu-west-3", "eu"}, + {"eu-central-1", "eu"}, + {"eu-central-2", "eu"}, + {"eu-north-1", "eu"}, + {"eu-south-1", "eu"}, + // APAC regions + {"ap-northeast-1", "jp"}, + {"ap-northeast-2", "apac"}, + {"ap-southeast-1", "apac"}, + {"ap-southeast-2", "au"}, + {"ap-south-1", "apac"}, + // Canada / South America fallback to us + {"ca-central-1", "us"}, + {"sa-east-1", "us"}, + // Unknown defaults to us + {"me-south-1", "us"}, + } + for _, tt := range tests { + t.Run(tt.region, func(t *testing.T) { + assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region)) + }) + } +} + +func TestResolveBedrockModelID(t *testing.T) { + t.Run("default alias resolves and adjusts region", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "eu-west-1", + }, + } + + modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5") + require.True(t, ok) + assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID) + }) + + t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "ap-southeast-2", + "model_mapping": map[string]any{ + "claude-*": "claude-opus-4-6", + }, + }, + } + + modelID, ok := 
ResolveBedrockModelID(account, "claude-opus-4-6-thinking") + require.True(t, ok) + assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID) + }) + + t.Run("force global rewrites anthropic regional model id", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + "aws_force_global": "true", + "model_mapping": map[string]any{ + "claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6", + }, + }, + } + + modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6") + require.True(t, ok) + assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID) + }) + + t.Run("direct bedrock model id passes through", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + }, + } + + modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0") + require.True(t, ok) + assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID) + }) + + t.Run("unsupported alias returns false", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_region": "us-east-1", + }, + } + + _, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022") + assert.False(t, ok) + }) +} + +func TestAutoInjectBedrockBetaTokens(t *testing.T) { + t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) { + body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "interleaved-thinking-2025-05-14") + }) + + t.Run("no duplicate when already present", func(t *testing.T) { + body := 
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1") + count := 0 + for _, t := range result { + if t == "interleaved-thinking-2025-05-14" { + count++ + } + } + assert.Equal(t, 1, count) + }) + + t.Run("inject computer-use when computer tool present", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "computer-use-2025-11-24") + }) + + t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "advanced-tool-use-2025-11-20") + }) + + t.Run("inject advanced-tool-use for input examples", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "advanced-tool-use-2025-11-20") + }) + + t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6") + // 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换 + assert.Contains(t, result, "tool-search-tool-2025-10-19") + assert.NotContains(t, result, 
"advanced-tool-use-2025-11-20") + }) + + t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6") + // 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool) + assert.Contains(t, result, "advanced-tool-use-2025-11-20") + }) + + t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) { + body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0") + assert.NotContains(t, result, "advanced-tool-use-2025-11-20") + assert.NotContains(t, result, "tool-search-tool-2025-10-19") + }) + + t.Run("no injection for regular tools", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Empty(t, result) + }) + + t.Run("no injection when no features detected", func(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`) + result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") + assert.Empty(t, result) + }) + + t.Run("preserves existing tokens", func(t *testing.T) { + body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`) + existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"} + result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "context-1m-2025-08-07") + 
assert.Contains(t, result, "compact-2026-01-12") + assert.Contains(t, result, "interleaved-thinking-2025-05-14") + }) +} + +func TestResolveBedrockBetaTokens(t *testing.T) { + t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) { + body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`) + result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1") + assert.Contains(t, result, "tool-search-tool-2025-10-19") + assert.Contains(t, result, "tool-examples-2025-10-29") + }) + + t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) { + body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`) + result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1") + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result) + }) +} + +func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) { + t.Run("thinking in body auto-injects beta without header", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") + require.NoError(t, err) + arr := gjson.GetBytes(result, "anthropic_beta").Array() + found := false + for _, v := range arr { + if v.String() == "interleaved-thinking-2025-05-14" { + found = true + } + } + assert.True(t, found, "interleaved-thinking should be auto-injected") + }) + + t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}` + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07") + require.NoError(t, err) + arr := 
gjson.GetBytes(result, "anthropic_beta").Array() + names := make([]string, len(arr)) + for i, v := range arr { + names[i] = v.String() + } + assert.Contains(t, names, "context-1m-2025-08-07") + assert.Contains(t, names, "interleaved-thinking-2025-05-14") + }) +} + +func TestAdjustBedrockModelRegionPrefix(t *testing.T) { + tests := []struct { + name string + modelID string + region string + expect string + }{ + // US region — no change needed + {"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"}, + // EU region — replace us → eu + {"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"}, + {"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"}, + // APAC region — jp and au have dedicated prefixes per AWS docs + {"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"}, + {"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"}, + {"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"}, + // eu → us (user manually set eu prefix, moved to us region) + {"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"}, + // global prefix — replace to match region + {"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"}, + // No known prefix — leave unchanged + {"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"}, + // GovCloud — uses independent us-gov prefix + {"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"}, + {"govcloud already 
correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"}, + // Force global (special region value) + {"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"}, + {"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region)) + }) + } +} diff --git a/backend/internal/service/bedrock_signer.go b/backend/internal/service/bedrock_signer.go new file mode 100644 index 00000000..e7000b4d --- /dev/null +++ b/backend/internal/service/bedrock_signer.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" +) + +// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名 +type BedrockSigner struct { + credentials aws.Credentials + region string + signer *v4.Signer +} + +// NewBedrockSigner 创建 BedrockSigner +func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner { + return &BedrockSigner{ + credentials: aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + }, + region: region, + signer: v4.NewSigner(), + } +} + +// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner +func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) { + accessKeyID := account.GetCredential("aws_access_key_id") + if accessKeyID == "" { + return nil, fmt.Errorf("aws_access_key_id not found in credentials") + } + secretAccessKey := account.GetCredential("aws_secret_access_key") + if secretAccessKey == "" { + return nil, fmt.Errorf("aws_secret_access_key not found in credentials") + } + region := account.GetCredential("aws_region") + if region == "" { + 
region = defaultBedrockRegion + } + sessionToken := account.GetCredential("aws_session_token") // 可选 + + return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil +} + +// SignRequest 对 HTTP 请求进行 SigV4 签名 +// 重要约束:调用此方法前,req 应只包含 AWS 相关的 header(如 Content-Type、Accept)。 +// 非 AWS header(如 anthropic-beta)会参与签名计算,如果 Bedrock 服务端不识别这些 header, +// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤, +// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept,因此是安全的。 +func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error { + payloadHash := sha256Hash(body) + return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now()) +} + +func sha256Hash(data []byte) string { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]) +} diff --git a/backend/internal/service/bedrock_signer_test.go b/backend/internal/service/bedrock_signer_test.go new file mode 100644 index 00000000..641e9341 --- /dev/null +++ b/backend/internal/service/bedrock_signer_test.go @@ -0,0 +1,35 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + Credentials: map[string]any{ + "aws_access_key_id": "test-akid", + "aws_secret_access_key": "test-secret", + }, + } + + signer, err := NewBedrockSignerFromAccount(account) + require.NoError(t, err) + require.NotNil(t, signer) + assert.Equal(t, defaultBedrockRegion, signer.region) +} + +func TestFilterBetaTokens(t *testing.T) { + tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"} + filterSet := map[string]struct{}{ + "tool-search-tool-2025-10-19": {}, + } + + assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet)) + assert.Equal(t, tokens, 
filterBetaTokens(tokens, nil)) + assert.Nil(t, filterBetaTokens(nil, filterSet)) +} diff --git a/backend/internal/service/bedrock_stream.go b/backend/internal/service/bedrock_stream.go new file mode 100644 index 00000000..30c1011b --- /dev/null +++ b/backend/internal/service/bedrock_stream.go @@ -0,0 +1,414 @@ +package service + +import ( + "bufio" + "context" + "encoding/base64" + "errors" + "fmt" + "hash/crc32" + "io" + "net/http" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应 +// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的 +// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。 +func (s *GatewayService) handleBedrockStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + + // Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。 + // 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4) + // 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。 + // 我们使用 EventStream decoder 来正确解析。 + decoder := newBedrockEventStreamDecoder(resp.Body) + + type decodeEvent struct { + payload []byte + err error + } + events := make(chan decodeEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev decodeEvent) bool { + select 
{ + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt atomic.Int64 + lastReadAt.Store(time.Now().UnixNano()) + + go func() { + defer close(events) + for { + payload, err := decoder.Decode() + if err != nil { + if err == io.EOF { + return + } + _ = sendEvent(decodeEvent{err: err}) + return + } + lastReadAt.Store(time.Now().UnixNano()) + if !sendEvent(decodeEvent{payload: payload}) { + return + } + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err) + } + + // payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据) + sseData := extractBedrockChunkData(ev.payload) + if sseData == nil { + continue + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 同时移除该字段避免透传给客户端 + sseData = 
transformBedrockInvocationMetrics(sseData) + + // 解析 SSE 事件数据提取 usage + s.parseSSEUsagePassthrough(string(sseData), usage) + + // 确定 SSE event type + eventType := gjson.GetBytes(sseData, "type").String() + + // 写入标准 SSE 格式 + if !clientDisconnected { + var writeErr error + if eventType != "" { + _, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData) + } else { + _, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData) + } + if writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, lastReadAt.Load()) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据 +// Bedrock payload 格式:{"bytes":""} +func extractBedrockChunkData(payload []byte) []byte { + b64 := gjson.GetBytes(payload, "bytes").String() + if b64 == "" { + return nil + } + decoded, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil + } + return decoded +} + +// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics +// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。 +// +// Bedrock Invoke 返回的 message_delta 事件可能包含: +// +// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}} +// +// 转换为: +// 
+// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}} +func transformBedrockInvocationMetrics(data []byte) []byte { + metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics") + if !metrics.Exists() || !metrics.IsObject() { + return data + } + + // 移除 Bedrock 特有字段 + data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics") + + // 如果已有标准 usage 字段,不覆盖 + if gjson.GetBytes(data, "usage").Exists() { + return data + } + + // 转换 camelCase → snake_case 写入 usage + inputTokens := metrics.Get("inputTokenCount") + outputTokens := metrics.Get("outputTokenCount") + if inputTokens.Exists() { + data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int()) + } + if outputTokens.Exists() { + data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int()) + } + + return data +} + +// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧 +// EventStream 帧格式: +// +// [total_byte_length: 4 bytes] +// [headers_byte_length: 4 bytes] +// [prelude_crc: 4 bytes] +// [headers: variable] +// [payload: variable] +// [message_crc: 4 bytes] +type bedrockEventStreamDecoder struct { + reader *bufio.Reader +} + +func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder { + return &bedrockEventStreamDecoder{ + reader: bufio.NewReaderSize(r, 64*1024), + } +} + +// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload +func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) { + for { + // 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes + prelude := make([]byte, 12) + if _, err := io.ReadFull(d.reader, prelude); err != nil { + return nil, err + } + + // 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE) + preludeCRC := bedrockReadUint32(prelude[8:12]) + if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC { + return nil, fmt.Errorf("eventstream prelude CRC mismatch") + } + + totalLength := bedrockReadUint32(prelude[0:4]) + headersLength := bedrockReadUint32(prelude[4:8]) + + if 
totalLength < 16 { // minimum: 12 prelude + 4 message_crc + return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength) + } + + // 读取 headers + payload + message_crc + remaining := int(totalLength) - 12 + if remaining <= 0 { + continue + } + data := make([]byte, remaining) + if _, err := io.ReadFull(d.reader, data); err != nil { + return nil, err + } + + // 验证 message CRC(覆盖 prelude + headers + payload) + messageCRC := bedrockReadUint32(data[len(data)-4:]) + h := crc32.New(crc32IEEETable) + h.Write(prelude) + h.Write(data[:len(data)-4]) + if h.Sum32() != messageCRC { + return nil, fmt.Errorf("eventstream message CRC mismatch") + } + + // 解析 headers + headers := data[:headersLength] + payload := data[headersLength : len(data)-4] // 去掉 message_crc + + // 从 headers 中提取 :event-type + eventType := extractEventStreamHeaderValue(headers, ":event-type") + + // 只处理 chunk 事件 + if eventType == "chunk" { + // payload 是完整的 JSON,包含 bytes 字段 + return payload, nil + } + + // 检查异常事件 + exceptionType := extractEventStreamHeaderValue(headers, ":exception-type") + if exceptionType != "" { + return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload)) + } + + messageType := extractEventStreamHeaderValue(headers, ":message-type") + if messageType == "exception" || messageType == "error" { + return nil, fmt.Errorf("bedrock error: %s", string(payload)) + } + + // 跳过其他事件类型(如 initial-response) + } +} + +// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值 +// EventStream header 格式: +// +// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable] +// +// value_type = 7 表示 string 类型,前 2 bytes 为长度 +func extractEventStreamHeaderValue(headers []byte, targetName string) string { + pos := 0 + for pos < len(headers) { + if pos >= len(headers) { + break + } + nameLen := int(headers[pos]) + pos++ + if pos+nameLen > len(headers) { + break + } + name := string(headers[pos : pos+nameLen]) + pos += nameLen + + if pos 
>= len(headers) { + break + } + valueType := headers[pos] + pos++ + + switch valueType { + case 7: // string + if pos+2 > len(headers) { + return "" + } + valueLen := int(bedrockReadUint16(headers[pos : pos+2])) + pos += 2 + if pos+valueLen > len(headers) { + return "" + } + value := string(headers[pos : pos+valueLen]) + pos += valueLen + if name == targetName { + return value + } + case 0: // bool true + if name == targetName { + return "true" + } + case 1: // bool false + if name == targetName { + return "false" + } + case 2: // byte + pos++ + if name == targetName { + return "" + } + case 3: // short + pos += 2 + if name == targetName { + return "" + } + case 4: // int + pos += 4 + if name == targetName { + return "" + } + case 5: // long + pos += 8 + if name == targetName { + return "" + } + case 6: // bytes + if pos+2 > len(headers) { + return "" + } + valueLen := int(bedrockReadUint16(headers[pos : pos+2])) + pos += 2 + valueLen + case 8: // timestamp + pos += 8 + case 9: // uuid + pos += 16 + default: + return "" // 未知类型,无法继续解析 + } + } + return "" +} + +// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream. 
+var crc32IEEETable = crc32.MakeTable(crc32.IEEE) + +func bedrockReadUint32(b []byte) uint32 { + return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) +} + +func bedrockReadUint16(b []byte) uint16 { + return uint16(b[0])<<8 | uint16(b[1]) +} diff --git a/backend/internal/service/bedrock_stream_test.go b/backend/internal/service/bedrock_stream_test.go new file mode 100644 index 00000000..500b9292 --- /dev/null +++ b/backend/internal/service/bedrock_stream_test.go @@ -0,0 +1,261 @@ +package service + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "hash/crc32" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestExtractBedrockChunkData(t *testing.T) { + t.Run("valid base64 payload", func(t *testing.T) { + original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}` + b64 := base64.StdEncoding.EncodeToString([]byte(original)) + payload := []byte(`{"bytes":"` + b64 + `"}`) + + result := extractBedrockChunkData(payload) + require.NotNil(t, result) + assert.JSONEq(t, original, string(result)) + }) + + t.Run("empty bytes field", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"bytes":""}`)) + assert.Nil(t, result) + }) + + t.Run("no bytes field", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"other":"value"}`)) + assert.Nil(t, result) + }) + + t.Run("invalid base64", func(t *testing.T) { + result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`)) + assert.Nil(t, result) + }) +} + +func TestTransformBedrockInvocationMetrics(t *testing.T) { + t.Run("converts metrics to usage", func(t *testing.T) { + input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}` + result := transformBedrockInvocationMetrics([]byte(input)) + + // amazon-bedrock-invocationMetrics should 
be removed + assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists()) + // usage should be set + assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int()) + assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int()) + // original fields preserved + assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String()) + assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String()) + }) + + t.Run("no metrics present", func(t *testing.T) { + input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}` + result := transformBedrockInvocationMetrics([]byte(input)) + assert.JSONEq(t, input, string(result)) + }) + + t.Run("does not overwrite existing usage", func(t *testing.T) { + input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}` + result := transformBedrockInvocationMetrics([]byte(input)) + + // metrics removed but existing usage preserved + assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists()) + assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int()) + }) +} + +func TestExtractEventStreamHeaderValue(t *testing.T) { + // Build a header with :event-type = "chunk" (string type = 7) + buildStringHeader := func(name, value string) []byte { + var buf bytes.Buffer + // name length (1 byte) + buf.WriteByte(byte(len(name))) + // name + buf.WriteString(name) + // value type (7 = string) + buf.WriteByte(7) + // value length (2 bytes, big-endian) + _ = binary.Write(&buf, binary.BigEndian, uint16(len(value))) + // value + buf.WriteString(value) + return buf.Bytes() + } + + t.Run("find string header", func(t *testing.T) { + headers := buildStringHeader(":event-type", "chunk") + assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) + }) + + t.Run("header not found", func(t *testing.T) 
{ + headers := buildStringHeader(":event-type", "chunk") + assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type")) + }) + + t.Run("multiple headers", func(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildStringHeader(":content-type", "application/json")) + buf.Write(buildStringHeader(":event-type", "chunk")) + buf.Write(buildStringHeader(":message-type", "event")) + + headers := buf.Bytes() + assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) + assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type")) + assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type")) + }) + + t.Run("empty headers", func(t *testing.T) { + assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type")) + }) +} + +func TestBedrockEventStreamDecoder(t *testing.T) { + crc32IeeeTab := crc32.MakeTable(crc32.IEEE) + + // Build a valid EventStream frame with correct CRC32/IEEE checksums. + buildFrame := func(eventType string, payload []byte) []byte { + // Build headers + var headersBuf bytes.Buffer + // :event-type header + headersBuf.WriteByte(byte(len(":event-type"))) + headersBuf.WriteString(":event-type") + headersBuf.WriteByte(7) // string type + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType))) + headersBuf.WriteString(eventType) + // :message-type header + headersBuf.WriteByte(byte(len(":message-type"))) + headersBuf.WriteString(":message-type") + headersBuf.WriteByte(7) + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event"))) + headersBuf.WriteString("event") + + headers := headersBuf.Bytes() + headersLen := uint32(len(headers)) + // total = 12 (prelude) + headers + payload + 4 (message_crc) + totalLen := uint32(12 + len(headers) + len(payload) + 4) + + // Prelude: total_length(4) + headers_length(4) + var preludeBuf bytes.Buffer + _ = binary.Write(&preludeBuf, binary.BigEndian, totalLen) + _ = binary.Write(&preludeBuf, 
binary.BigEndian, headersLen) + preludeBytes := preludeBuf.Bytes() + preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab) + + // Build frame: prelude + prelude_crc + headers + payload + var frame bytes.Buffer + frame.Write(preludeBytes) + _ = binary.Write(&frame, binary.BigEndian, preludeCRC) + frame.Write(headers) + frame.Write(payload) + + // Message CRC covers everything before itself + messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab) + _ = binary.Write(&frame, binary.BigEndian, messageCRC) + return frame.Bytes() + } + + t.Run("decode chunk event", func(t *testing.T) { + payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test") + frame := buildFrame("chunk", payload) + + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + result, err := decoder.Decode() + require.NoError(t, err) + assert.Equal(t, payload, result) + }) + + t.Run("skip non-chunk events", func(t *testing.T) { + // Write initial-response followed by chunk + var buf bytes.Buffer + buf.Write(buildFrame("initial-response", []byte(`{}`))) + chunkPayload := []byte(`{"bytes":"aGVsbG8="}`) + buf.Write(buildFrame("chunk", chunkPayload)) + + decoder := newBedrockEventStreamDecoder(&buf) + result, err := decoder.Decode() + require.NoError(t, err) + assert.Equal(t, chunkPayload, result) + }) + + t.Run("EOF on empty input", func(t *testing.T) { + decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil)) + _, err := decoder.Decode() + assert.Equal(t, io.EOF, err) + }) + + t.Run("corrupted prelude CRC", func(t *testing.T) { + frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`)) + // Corrupt the prelude CRC (bytes 8-11) + frame[8] ^= 0xFF + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "prelude CRC mismatch") + }) + + t.Run("corrupted message CRC", func(t *testing.T) { + frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`)) + // Corrupt the message CRC (last 4 
bytes) + frame[len(frame)-1] ^= 0xFF + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame)) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "message CRC mismatch") + }) + + t.Run("castagnoli encoded frame is rejected", func(t *testing.T) { + castagnoliTab := crc32.MakeTable(crc32.Castagnoli) + payload := []byte(`{"bytes":"dGVzdA=="}`) + + var headersBuf bytes.Buffer + headersBuf.WriteByte(byte(len(":event-type"))) + headersBuf.WriteString(":event-type") + headersBuf.WriteByte(7) + _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk"))) + headersBuf.WriteString("chunk") + + headers := headersBuf.Bytes() + headersLen := uint32(len(headers)) + totalLen := uint32(12 + len(headers) + len(payload) + 4) + + var preludeBuf bytes.Buffer + _ = binary.Write(&preludeBuf, binary.BigEndian, totalLen) + _ = binary.Write(&preludeBuf, binary.BigEndian, headersLen) + preludeBytes := preludeBuf.Bytes() + + var frame bytes.Buffer + frame.Write(preludeBytes) + _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab)) + frame.Write(headers) + frame.Write(payload) + _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab)) + + decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes())) + _, err := decoder.Decode() + require.Error(t, err) + assert.Contains(t, err.Error(), "prelude CRC mismatch") + }) +} + +func TestBuildBedrockURL(t *testing.T) { + t.Run("stream URL with colon in model ID", func(t *testing.T) { + url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true) + assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url) + }) + + t.Run("non-stream URL with colon in model ID", func(t *testing.T) { + url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false) + assert.Equal(t, 
"https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url) + }) + + t.Run("model ID without colon", func(t *testing.T) { + url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true) + assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url) + }) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 304c09f4..26905fb9 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -33,6 +33,8 @@ const ( AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) + AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) ) // Redeem type constants diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8a433a36..978ec98f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3336,6 +3336,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo if account.Platform == PlatformSora { return s.isSoraModelSupportedByAccount(account, requestedModel) } + if account.IsBedrock() { + _, ok := ResolveBedrockModelID(account, requestedModel) + return ok + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -3493,6 +3497,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( return "", "", 
errors.New("api_key not found in credentials") } return apiKey, "apikey", nil + case AccountTypeBedrock: + return "", "bedrock", nil // Bedrock 使用 SigV4 签名,不需要 token + case AccountTypeBedrockAPIKey: + return "", "bedrock-apikey", nil // Bedrock API Key 使用 Bearer Token,由 forwardBedrock 处理 default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } @@ -3948,6 +3956,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) } + if account != nil && account.IsBedrock() { + return s.forwardBedrock(ctx, c, account, parsed, startTime) + } + // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. if account.Platform == PlatformAnthropic && c != nil { @@ -5068,6 +5080,366 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, } } +// forwardBedrock 转发请求到 AWS Bedrock +func (s *GatewayService) forwardBedrock( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *ParsedRequest, + startTime time.Time, +) (*ForwardResult, error) { + reqModel := parsed.Model + reqStream := parsed.Stream + body := parsed.Body + + region := bedrockRuntimeRegion(account) + mappedModel, ok := ResolveBedrockModelID(account, reqModel) + if !ok { + return nil, fmt.Errorf("unsupported bedrock model: %s", reqModel) + } + if mappedModel != reqModel { + logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + + betaHeader := "" + if c != nil && c.Request != nil { + betaHeader = c.GetHeader("anthropic-beta") + } + + // 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control) + betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, 
mappedModel) + if err != nil { + return nil, err + } + + bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens) + if err != nil { + return nil, fmt.Errorf("prepare bedrock request body: %w", err) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v", + account.ID, account.Name, reqModel, mappedModel, reqStream) + + // 根据账号类型选择认证方式 + var signer *BedrockSigner + var bedrockAPIKey string + if account.IsBedrockAPIKey() { + bedrockAPIKey = account.GetCredential("api_key") + if bedrockAPIKey == "" { + return nil, fmt.Errorf("api_key not found in bedrock-apikey credentials") + } + } else { + signer, err = NewBedrockSignerFromAccount(account) + if err != nil { + return nil, fmt.Errorf("create bedrock signer: %w", err) + } + } + + // 执行上游请求(含重试) + resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + // 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id, + // 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。 + if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" { + resp.Header.Set("x-request-id", awsReqID) + } + + // 错误/failover 处理 + if resp.StatusCode >= 400 { + return s.handleBedrockUpstreamErrors(ctx, resp, c, account) + } + + // 响应处理 + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = 
s.handleBedrockNonStreamingResponse(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-amzn-requestid"), + Usage: *usage, + Model: reqModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑) +func (s *GatewayService) executeBedrockUpstream( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, + apiKey string, + proxyURL string, +) (*http.Response, error) { + var resp *http.Response + var err error + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + var upstreamReq *http.Request + if account.IsBedrockAPIKey() { + upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey) + } else { + upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer) + } + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, false) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + 
if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + return resp, nil +} + +// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应) +func (s *GatewayService) handleBedrockUpstreamErrors( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ForwardResult, error) { + // retry exhausted + failover + if s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s", + account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000)) + + 
s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + // non-retryable failover + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + } + } + + // other errors + return s.handleErrorResponse(ctx, resp, c, account) +} + +// buildUpstreamRequestBedrock 构建 Bedrock 上游请求 +func (s *GatewayService) buildUpstreamRequestBedrock( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + signer *BedrockSigner, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + // SigV4 签名 + if err := signer.SignRequest(ctx, req, body); err != nil { + return nil, fmt.Errorf("sign bedrock request: %w", err) + } + + return req, nil +} + +// buildUpstreamRequestBedrockAPIKey 构建 Bedrock API Key (Bearer Token) 上游请求 +func (s 
*GatewayService) buildUpstreamRequestBedrockAPIKey( + ctx context.Context, + body []byte, + modelID string, + region string, + stream bool, + apiKey string, +) (*http.Request, error) { + targetURL := BuildBedrockURL(region, modelID, stream) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + return req, nil +} + +// handleBedrockNonStreamingResponse 处理 Bedrock 非流式响应 +// Bedrock InvokeModel 非流式响应的 body 格式与 Claude API 兼容 +func (s *GatewayService) handleBedrockNonStreamingResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + // 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式 + // 并移除该字段避免透传给客户端 + body = transformBedrockInvocationMetrics(body) + + usage := parseClaudeUsageFromResponseBody(body) + + c.Header("Content-Type", "application/json") + if v := resp.Header.Get("x-amzn-requestid"); v != "" { + c.Header("x-request-id", v) + } + c.Data(resp.StatusCode, "application/json", body) + return usage, nil +} + func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL @@ -5481,6 +5853,95 @@ func 
containsBetaToken(header, token string) bool {
 	return false
 }
 
+// filterBetaTokensFromHeader removes tokens present in filterSet from a comma-separated header value.
+// Returns the filtered header string, or "" if all tokens were removed.
+func filterBetaTokensFromHeader(header string, filterSet map[string]struct{}) string {
+	if header == "" || len(filterSet) == 0 {
+		return header
+	}
+	var kept []string
+	for _, p := range strings.Split(header, ",") {
+		t := strings.TrimSpace(p)
+		if t == "" {
+			continue
+		}
+		if _, filtered := filterSet[t]; !filtered {
+			kept = append(kept, t)
+		}
+	}
+	return strings.Join(kept, ", ")
+}
+
+func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string {
+	if len(tokens) == 0 || len(filterSet) == 0 {
+		return tokens
+	}
+	kept := make([]string, 0, len(tokens))
+	for _, token := range tokens {
+		if _, filtered := filterSet[token]; !filtered {
+			kept = append(kept, token)
+		}
+	}
+	return kept
+}
+
+func (s *GatewayService) resolveBedrockBetaTokensForRequest(
+	ctx context.Context,
+	account *Account,
+	betaHeader string,
+	body []byte,
+	modelID string,
+) ([]string, error) {
+	// 1. Run the block check on the beta tokens in the original header (fail fast).
+	policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
+	if policy.blockErr != nil {
+		return nil, policy.blockErr
+	}
+
+	// 2. Resolve the tokens: header parsing + body-based auto-injection + Bedrock transform/filter.
+	betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
+
+	// 3. Re-run the block check on the final token list to catch tokens that bypass the header check via body auto-injection.
+	// Example: an admin blocks interleaved-thinking and the client omits that token from the header,
+	// but the request body contains a thinking field → autoInjectBedrockBetaTokens adds the token →
+	// without this check the block rule would be bypassed.
+	if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
+		return nil, blockErr
+	}
+
+	return filterBetaTokens(betaTokens, policy.filterSet), nil
+}
+
+// checkBetaPolicyBlockForTokens reports whether any token in the list is hit by an admin block rule.
+// It supplements evaluateBetaPolicy's header check by also covering tokens auto-injected from the body.
+func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
+	if s.settingService == nil || len(tokens) == 0 {
+		return nil
+	}
+	settings, err := s.settingService.GetBetaPolicySettings(ctx)
+	if err != nil || settings == nil {
+		return nil
+	}
+	isOAuth := account.IsOAuth()
+	tokenSet := buildBetaTokenSet(tokens)
+	for _, rule := range settings.Rules {
+		if rule.Action != BetaPolicyActionBlock {
+			continue
+		}
+		if !betaPolicyScopeMatches(rule.Scope, isOAuth) {
+			continue
+		}
+		if _, present := tokenSet[rule.BetaToken]; present {
+			msg := rule.ErrorMessage
+			if msg == "" {
+				msg = "beta feature " + rule.BetaToken + " is not allowed"
+			}
+			return &BetaBlockedError{Message: msg}
+		}
+	}
+	return nil
+}
+
 func buildBetaTokenSet(tokens []string) map[string]struct{} {
 	m := make(map[string]struct{}, len(tokens))
 	for _, t := range tokens {
@@ -7064,6 +7525,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
 		return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
 	}
 
+	// Bedrock does not support the count_tokens endpoint.
+	if account != nil && account.IsBedrock() {
+		s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock")
+		return nil
+	}
+
 	body := parsed.Body
 	reqModel := parsed.Model
 
diff --git a/backend/internal/service/gateway_service_bedrock_beta_test.go
b/backend/internal/service/gateway_service_bedrock_beta_test.go
new file mode 100644
index 00000000..8920ee08
--- /dev/null
+++ b/backend/internal/service/gateway_service_bedrock_beta_test.go
@@ -0,0 +1,273 @@
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"testing"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+// betaPolicySettingRepoStub is a settings-repository test double backed by an in-memory map.
+// Only GetValue is implemented; every other method panics to flag an unexpected call.
+type betaPolicySettingRepoStub struct {
+	values map[string]string
+}
+
+func (s *betaPolicySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+	panic("unexpected Get call")
+}
+
+func (s *betaPolicySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+	if v, ok := s.values[key]; ok {
+		return v, nil
+	}
+	return "", ErrSettingNotFound
+}
+
+func (s *betaPolicySettingRepoStub) Set(ctx context.Context, key, value string) error {
+	panic("unexpected Set call")
+}
+
+func (s *betaPolicySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+	panic("unexpected GetMultiple call")
+}
+
+func (s *betaPolicySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+	panic("unexpected SetMultiple call")
+}
+
+func (s *betaPolicySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+	panic("unexpected GetAll call")
+}
+
+func (s *betaPolicySettingRepoStub) Delete(ctx context.Context, key string) error {
+	panic("unexpected Delete call")
+}
+
+// TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken verifies that a block rule
+// on a token present in the incoming beta header rejects the request with the rule's error message.
+func TestResolveBedrockBetaTokensForRequest_BlocksOnOriginalAnthropicToken(t *testing.T) {
+	settings := &BetaPolicySettings{
+		Rules: []BetaPolicyRule{
+			{
+				BetaToken:    "advanced-tool-use-2025-11-20",
+				Action:       BetaPolicyActionBlock,
+				Scope:        BetaPolicyScopeAll,
+				ErrorMessage: "advanced tool use is blocked",
+			},
+		},
+	}
+	raw, err := json.Marshal(settings)
+	if err != nil {
+		t.Fatalf("marshal settings: %v", err)
+	}
+
+	svc := &GatewayService{
+		settingService: NewSettingService(
+			&betaPolicySettingRepoStub{values: map[string]string{
+				SettingKeyBetaPolicySettings: string(raw),
+			}},
+			&config.Config{},
+		),
+	}
+	account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
+
+	_, err = svc.resolveBedrockBetaTokensForRequest(
+		context.Background(),
+		account,
+		"advanced-tool-use-2025-11-20",
+		[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
+		"us.anthropic.claude-opus-4-6-v1",
+	)
+	if err == nil {
+		t.Fatal("expected raw advanced-tool-use token to be blocked before Bedrock transform")
+	}
+	if err.Error() != "advanced tool use is blocked" {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+// TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform verifies that a filter rule for
+// tool-search-tool-2025-10-19 keeps that token out of the result when the header carries advanced-tool-use-2025-11-20.
+func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *testing.T) {
+	settings := &BetaPolicySettings{
+		Rules: []BetaPolicyRule{
+			{
+				BetaToken: "tool-search-tool-2025-10-19",
+				Action:    BetaPolicyActionFilter,
+				Scope:     BetaPolicyScopeAll,
+			},
+		},
+	}
+	raw, err := json.Marshal(settings)
+	if err != nil {
+		t.Fatalf("marshal settings: %v", err)
+	}
+
+	svc := &GatewayService{
+		settingService: NewSettingService(
+			&betaPolicySettingRepoStub{values: map[string]string{
+				SettingKeyBetaPolicySettings: string(raw),
+			}},
+			&config.Config{},
+		),
+	}
+	account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
+
+	betaTokens, err := svc.resolveBedrockBetaTokensForRequest(
+		context.Background(),
+		account,
+		"advanced-tool-use-2025-11-20",
+		[]byte(`{"messages":[{"role":"user","content":"hi"}]}`),
+		"us.anthropic.claude-opus-4-6-v1",
+	)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	for _, token := range betaTokens {
+		if token == "tool-search-tool-2025-10-19" {
+			t.Fatalf("expected transformed Bedrock token to be filtered")
+		}
+	}
+}
+
+// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking verifies that when
+// an admin blocks interleaved-thinking and the client omits that token from the header,
+// a request body containing a thinking field is still blocked once the token is auto-injected.
+func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
+	settings := &BetaPolicySettings{
+		Rules: []BetaPolicyRule{
+			{
+				BetaToken:    "interleaved-thinking-2025-05-14",
+				Action:       BetaPolicyActionBlock,
+				Scope:        BetaPolicyScopeAll,
+				ErrorMessage: "thinking is blocked",
+			},
+		},
+	}
+	raw, err := json.Marshal(settings)
+	if err != nil {
+		t.Fatalf("marshal settings: %v", err)
+	}
+
+	svc := &GatewayService{
+		settingService: NewSettingService(
+			&betaPolicySettingRepoStub{values: map[string]string{
+				SettingKeyBetaPolicySettings: string(raw),
+			}},
+			&config.Config{},
+		),
+	}
+	account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
+
+	// No beta token in the header, but the body has a thinking field.
+	_, err = svc.resolveBedrockBetaTokensForRequest(
+		context.Background(),
+		account,
+		"", // empty header
+		[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
+		"us.anthropic.claude-opus-4-6-v1",
+	)
+	if err == nil {
+		t.Fatal("expected body-injected interleaved-thinking to be blocked")
+	}
+	if err.Error() != "thinking is blocked" {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch verifies that when
+// an admin blocks tool-search-tool and the client sends no beta token in the header,
+// a request body containing a tool search tool is still blocked once the token is auto-injected.
+func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedToolSearch(t *testing.T) {
+	settings := &BetaPolicySettings{
+		Rules: []BetaPolicyRule{
+			{
+				BetaToken:    "tool-search-tool-2025-10-19",
+				Action:       BetaPolicyActionBlock,
+				Scope:        BetaPolicyScopeAll,
+				ErrorMessage: "tool search is blocked",
+			},
+		},
+	}
+	raw, err := json.Marshal(settings)
+	if err != nil {
+		t.Fatalf("marshal settings: %v", err)
+	}
+
+	svc := &GatewayService{
+		settingService: NewSettingService(
+			&betaPolicySettingRepoStub{values: map[string]string{
+				SettingKeyBetaPolicySettings: string(raw),
+			}},
+			&config.Config{},
+		),
+	}
+	account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
+
+	// No beta token in the header, but the body has a tool_search_tool tool.
+	_, err = svc.resolveBedrockBetaTokensForRequest(
+		context.Background(),
+		account,
+		"",
+		[]byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`),
+		"us.anthropic.claude-sonnet-4-6",
+	)
+	if err == nil {
+		t.Fatal("expected body-injected tool-search-tool to be blocked")
+	}
+	if err.Error() != "tool search is blocked" {
+		t.Fatalf("unexpected error: %v", err)
+	}
+}
+
+// TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches verifies that
+// a token auto-injected from the body passes through when no block rule targets it.
+func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *testing.T) {
+	settings := &BetaPolicySettings{
+		Rules: []BetaPolicyRule{
+			{
+				BetaToken:    "computer-use-2025-11-24",
+				Action:       BetaPolicyActionBlock,
+				Scope:        BetaPolicyScopeAll,
+				ErrorMessage: "computer use is blocked",
+			},
+		},
+	}
+	raw, err := json.Marshal(settings)
+	if err != nil {
+		t.Fatalf("marshal settings: %v", err)
+	}
+
+	svc := &GatewayService{
+		settingService: NewSettingService(
+			&betaPolicySettingRepoStub{values: map[string]string{
+				SettingKeyBetaPolicySettings: string(raw),
+			}},
+			&config.Config{},
+		),
+	}
+	account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
+
+	// The body has thinking (which injects interleaved-thinking), but the block rule only targets computer-use.
+	tokens, err := svc.resolveBedrockBetaTokensForRequest(
+		context.Background(),
+		account,
+		"",
+		[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
+		"us.anthropic.claude-opus-4-6-v1",
+	)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	found := false
+	for _, token := range tokens {
+		if token == "interleaved-thinking-2025-05-14" {
+			found = true
+		}
+	}
+	if !found {
+		t.Fatal("expected interleaved-thinking token to be present")
+	}
+}
diff --git a/backend/internal/service/gateway_service_bedrock_model_support_test.go
b/backend/internal/service/gateway_service_bedrock_model_support_test.go
new file mode 100644
index 00000000..aa8d4756
--- /dev/null
+++ b/backend/internal/service/gateway_service_bedrock_model_support_test.go
@@ -0,0 +1,53 @@
+package service
+
+import "testing"
+
+// TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels checks a Bedrock account
+// with no model_mapping credential: claude-sonnet-4-5 must be supported and claude-3-5-sonnet-20241022 rejected.
+func TestGatewayServiceIsModelSupportedByAccount_BedrockDefaultMappingRestrictsModels(t *testing.T) {
+	svc := &GatewayService{}
+	account := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeBedrock,
+		Credentials: map[string]any{
+			"aws_region": "us-east-1",
+		},
+	}
+
+	if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-5") {
+		t.Fatalf("expected default Bedrock alias to be supported")
+	}
+
+	if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
+		t.Fatalf("expected unsupported alias to be rejected for Bedrock account")
+	}
+}
+
+// TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist checks a Bedrock account
+// with a "claude-sonnet-*" model_mapping: claude-sonnet-4-6 and claude-opus-4-6 must be supported,
+// while claude-3-5-sonnet-20241022 must still be rejected.
+func TestGatewayServiceIsModelSupportedByAccount_BedrockCustomMappingStillActsAsAllowlist(t *testing.T) {
+	svc := &GatewayService{}
+	account := &Account{
+		Platform: PlatformAnthropic,
+		Type:     AccountTypeBedrock,
+		Credentials: map[string]any{
+			"aws_region": "eu-west-1",
+			"model_mapping": map[string]any{
+				"claude-sonnet-*": "claude-sonnet-4-6",
+			},
+		},
+	}
+
+	if !svc.isModelSupportedByAccount(account, "claude-sonnet-4-6") {
+		t.Fatalf("expected matched custom mapping to be supported")
+	}
+
+	if !svc.isModelSupportedByAccount(account, "claude-opus-4-6") {
+		t.Fatalf("expected default Bedrock alias fallback to remain supported")
+	}
+
+	if svc.isModelSupportedByAccount(account, "claude-3-5-sonnet-20241022") {
+		t.Fatalf("expected unsupported model to still be rejected")
+	}
+}
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index 8423c1b9..1ac96ed6 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -232,7 +232,7 @@
-
+
+ + + +
@@ -896,7 +956,7 @@ -
+
+ +
+
+ + +
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockSessionTokenHint') }}

+
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+ + +
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+
@@ -2671,7 +3014,7 @@ interface TempUnschedRuleForm { // State const step = ref(1) const submitting = ref(false) -const accountCategory = ref<'oauth-based' | 'apikey'>('oauth-based') // UI selection for account category +const accountCategory = ref<'oauth-based' | 'apikey' | 'bedrock' | 'bedrock-apikey'>('oauth-based') // UI selection for account category const addMethod = ref('oauth') // For oauth-based: 'oauth' or 'setup-token' const apiKeyBaseUrl = ref('https://api.anthropic.com') const apiKeyValue = ref('') @@ -2704,6 +3047,19 @@ const antigravityModelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist' const antigravityWhitelistModels = ref([]) const antigravityModelMappings = ref([]) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) +const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) + +// Bedrock credentials +const bedrockAccessKeyId = ref('') +const bedrockSecretAccessKey = ref('') +const bedrockSessionToken = ref('') +const bedrockRegion = ref('us-east-1') +const bedrockForceGlobal = ref(false) + +// Bedrock API Key credentials +const bedrockApiKeyValue = ref('') +const bedrockApiKeyRegion = ref('us-east-1') +const bedrockApiKeyForceGlobal = ref(false) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('create-model-mapping') @@ -2868,6 +3224,10 @@ const isOAuthFlow = computed(() => { if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { return false } + // Bedrock 类型不需要 OAuth 流程 + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') { + return false + } return accountCategory.value === 'oauth-based' }) @@ -2935,6 +3295,11 @@ watch( form.type = 'apikey' return } + // Bedrock 类型 + if (form.platform === 'anthropic' && category === 'bedrock') { + form.type = 'bedrock' as AccountType + return + } if (category === 'oauth-based') { form.type = method as 
AccountType // 'oauth' or 'setup-token' } else { @@ -2972,6 +3337,13 @@ watch( antigravityModelMappings.value = [] antigravityModelRestrictionMode.value = 'mapping' } + // Reset Bedrock fields when switching platforms + bedrockAccessKeyId.value = '' + bedrockSecretAccessKey.value = '' + bedrockSessionToken.value = '' + bedrockRegion.value = 'us-east-1' + bedrockForceGlobal.value = false + bedrockApiKeyForceGlobal.value = false // Reset Anthropic/Antigravity-specific settings when switching to other platforms if (newPlatform !== 'anthropic' && newPlatform !== 'antigravity') { interceptWarmupRequests.value = false @@ -3541,6 +3913,84 @@ const handleSubmit = async () => { return } + // For Bedrock type, create directly + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!bedrockAccessKeyId.value.trim()) { + appStore.showError(t('admin.accounts.bedrockAccessKeyIdRequired')) + return + } + if (!bedrockSecretAccessKey.value.trim()) { + appStore.showError(t('admin.accounts.bedrockSecretAccessKeyRequired')) + return + } + if (!bedrockRegion.value.trim()) { + appStore.showError(t('admin.accounts.bedrockRegionRequired')) + return + } + + const credentials: Record = { + aws_access_key_id: bedrockAccessKeyId.value.trim(), + aws_secret_access_key: bedrockSecretAccessKey.value.trim(), + aws_region: bedrockRegion.value.trim(), + } + if (bedrockSessionToken.value.trim()) { + credentials.aws_session_token = bedrockSessionToken.value.trim() + } + if (bedrockForceGlobal.value) { + credentials.aws_force_global = 'true' + } + + // Model mapping + const modelMapping = buildModelMappingObject( + modelRestrictionMode.value, allowedModels.value, modelMappings.value + ) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') + + await 
createAccountAndFinish('anthropic', 'bedrock' as AccountType, credentials) + return + } + + // For Bedrock API Key type, create directly + if (form.platform === 'anthropic' && accountCategory.value === 'bedrock-apikey') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!bedrockApiKeyValue.value.trim()) { + appStore.showError(t('admin.accounts.bedrockApiKeyRequired')) + return + } + + const credentials: Record = { + api_key: bedrockApiKeyValue.value.trim(), + aws_region: bedrockApiKeyRegion.value.trim() || 'us-east-1', + } + if (bedrockApiKeyForceGlobal.value) { + credentials.aws_force_global = 'true' + } + + // Model mapping + const modelMapping = buildModelMappingObject( + modelRestrictionMode.value, allowedModels.value, modelMappings.value + ) + if (modelMapping) { + credentials.model_mapping = modelMapping + } + + applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') + + await createAccountAndFinish('anthropic', 'bedrock-apikey' as AccountType, credentials) + return + } + // For Antigravity upstream type, create directly if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { if (!form.name.trim()) { diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 1f2e988c..b18e9db6 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -563,6 +563,233 @@
+ +
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockSecretKeyLeaveEmpty') }}

+
+
+ + +

{{ t('admin.accounts.bedrockSessionTokenHint') }}

+
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+ + +
+
+ + +

{{ t('admin.accounts.bedrockApiKeyLeaveEmpty') }}

+
+
+ + +

{{ t('admin.accounts.bedrockRegionHint') }}

+
+
+ +

{{ t('admin.accounts.bedrockForceGlobalHint') }}

+
+ + +
+ + + +
+ + +
+ + +
+ +

+ {{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }} + {{ t('admin.accounts.supportsAllModels') }} +

+
+ + +
+
+ + + + +
+ + +
+ +
+
+
+
+
@@ -1529,6 +1756,7 @@ const baseUrlHint = computed(() => { }) const antigravityPresetMappings = computed(() => getPresetMappingsByPlatform('antigravity')) +const bedrockPresets = computed(() => getPresetMappingsByPlatform('bedrock')) // Model mapping type interface ModelMapping { @@ -1547,6 +1775,17 @@ interface TempUnschedRuleForm { const submitting = ref(false) const editBaseUrl = ref('https://api.anthropic.com') const editApiKey = ref('') +// Bedrock credentials +const editBedrockAccessKeyId = ref('') +const editBedrockSecretAccessKey = ref('') +const editBedrockSessionToken = ref('') +const editBedrockRegion = ref('') +const editBedrockForceGlobal = ref(false) + +// Bedrock API Key credentials +const editBedrockApiKeyValue = ref('') +const editBedrockApiKeyRegion = ref('') +const editBedrockApiKeyForceGlobal = ref(false) const modelMappings = ref([]) const modelRestrictionMode = ref<'whitelist' | 'mapping'>('whitelist') const allowedModels = ref([]) @@ -1889,6 +2128,58 @@ watch( } else { selectedErrorCodes.value = [] } + } else if (newAccount.type === 'bedrock' && newAccount.credentials) { + const bedrockCreds = newAccount.credentials as Record + editBedrockAccessKeyId.value = (bedrockCreds.aws_access_key_id as string) || '' + editBedrockRegion.value = (bedrockCreds.aws_region as string) || '' + editBedrockForceGlobal.value = (bedrockCreds.aws_force_global as string) === 'true' + editBedrockSecretAccessKey.value = '' + editBedrockSessionToken.value = '' + + // Load model mappings for bedrock + const existingMappings = bedrockCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 
'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } + } else if (newAccount.type === 'bedrock-apikey' && newAccount.credentials) { + const bedrockApiKeyCreds = newAccount.credentials as Record + editBedrockApiKeyRegion.value = (bedrockApiKeyCreds.aws_region as string) || 'us-east-1' + editBedrockApiKeyForceGlobal.value = (bedrockApiKeyCreds.aws_force_global as string) === 'true' + editBedrockApiKeyValue.value = '' + + // Load model mappings for bedrock-apikey + const existingMappings = bedrockApiKeyCreds.model_mapping as Record | undefined + if (existingMappings && typeof existingMappings === 'object') { + const entries = Object.entries(existingMappings) + const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) + if (isWhitelistMode) { + modelRestrictionMode.value = 'whitelist' + allowedModels.value = entries.map(([from]) => from) + modelMappings.value = [] + } else { + modelRestrictionMode.value = 'mapping' + modelMappings.value = entries.map(([from, to]) => ({ from, to })) + allowedModels.value = [] + } + } else { + modelRestrictionMode.value = 'whitelist' + modelMappings.value = [] + allowedModels.value = [] + } } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' @@ -2431,6 +2722,70 @@ const handleSubmit = async () => { return } + updatePayload.credentials = newCredentials + } else if (props.account.type === 'bedrock') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.aws_access_key_id = editBedrockAccessKeyId.value.trim() + newCredentials.aws_region = editBedrockRegion.value.trim() + if (editBedrockForceGlobal.value) { + 
newCredentials.aws_force_global = 'true' + } else { + delete newCredentials.aws_force_global + } + + // Only update secrets if user provided new values + if (editBedrockSecretAccessKey.value.trim()) { + newCredentials.aws_secret_access_key = editBedrockSecretAccessKey.value.trim() + } + if (editBedrockSessionToken.value.trim()) { + newCredentials.aws_session_token = editBedrockSessionToken.value.trim() + } + + // Model mapping + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + + updatePayload.credentials = newCredentials + } else if (props.account.type === 'bedrock-apikey') { + const currentCredentials = (props.account.credentials as Record) || {} + const newCredentials: Record = { ...currentCredentials } + + newCredentials.aws_region = editBedrockApiKeyRegion.value.trim() || 'us-east-1' + if (editBedrockApiKeyForceGlobal.value) { + newCredentials.aws_force_global = 'true' + } else { + delete newCredentials.aws_force_global + } + + // Only update API key if user provided new value + if (editBedrockApiKeyValue.value.trim()) { + newCredentials.api_key = editBedrockApiKeyValue.value.trim() + } + + // Model mapping + const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + if (modelMapping) { + newCredentials.model_mapping = modelMapping + } else { + delete newCredentials.model_mapping + } + + applyInterceptWarmup(newCredentials, interceptWarmupRequests.value, 'edit') + if (!applyTempUnschedConfig(newCredentials)) { + return + } + updatePayload.credentials = newCredentials } else { // For oauth/setup-token types, only update intercept_warmup_requests if changed diff --git 
a/frontend/src/components/common/PlatformTypeBadge.vue b/frontend/src/components/common/PlatformTypeBadge.vue index f0625e88..a6ff490e 100644 --- a/frontend/src/components/common/PlatformTypeBadge.vue +++ b/frontend/src/components/common/PlatformTypeBadge.vue @@ -82,6 +82,8 @@ const typeLabel = computed(() => { return 'Token' case 'apikey': return 'Key' + case 'bedrock': + return 'Bedrock' default: return props.type } diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 09a150cb..b47b895c 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -331,6 +331,15 @@ const antigravityPresetMappings = [ { label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' } ] +// Bedrock 预设映射(与后端 DefaultBedrockModelMapping 保持一致) +const bedrockPresetMappings = [ + { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, + { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, + { label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, + { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, + { label: 'Haiku 4.5', from: 'claude-haiku-4-5', to: 'us.anthropic.claude-haiku-4-5-20251001-v1:0', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' }, +] + // Antigravity 默认映射(从后端 
API 获取,与 constants.go 保持一致) // 使用 fetchAntigravityDefaultMappings() 异步获取 import { getAntigravityDefaultModelMapping } from '@/api/admin/accounts' @@ -403,6 +412,7 @@ export function getPresetMappingsByPlatform(platform: string) { if (platform === 'gemini') return geminiPresetMappings if (platform === 'sora') return soraPresetMappings if (platform === 'antigravity') return antigravityPresetMappings + if (platform === 'bedrock' || platform === 'bedrock-apikey') return bedrockPresetMappings return anthropicPresetMappings } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 100899c3..34b90ce1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1921,6 +1921,8 @@ export default { accountType: 'Account Type', claudeCode: 'Claude Code', claudeConsole: 'Claude Console', + bedrockLabel: 'AWS Bedrock', + bedrockDesc: 'SigV4 Signing', oauthSetupToken: 'OAuth / Setup Token', addMethod: 'Add Method', setupTokenLongLived: 'Setup Token (Long-lived)', @@ -2110,6 +2112,23 @@ export default { mixedChannelWarning: 'Warning: Group "{groupName}" contains both {currentPlatform} and {otherPlatform} accounts. Mixing different channels may cause thinking block signature validation issues, which will fallback to non-thinking mode. Are you sure you want to continue?', pleaseEnterAccountName: 'Please enter account name', pleaseEnterApiKey: 'Please enter API Key', + bedrockAccessKeyId: 'AWS Access Key ID', + bedrockSecretAccessKey: 'AWS Secret Access Key', + bedrockSessionToken: 'AWS Session Token', + bedrockRegion: 'AWS Region', + bedrockRegionHint: 'e.g. us-east-1, us-west-2, eu-west-1', + bedrockForceGlobal: 'Force Global cross-region inference', + bedrockForceGlobalHint: 'When enabled, model IDs use the global. prefix (e.g. 
global.anthropic.claude-...), routing requests to any supported region worldwide for higher availability', + bedrockAccessKeyIdRequired: 'Please enter AWS Access Key ID', + bedrockSecretAccessKeyRequired: 'Please enter AWS Secret Access Key', + bedrockRegionRequired: 'Please select AWS Region', + bedrockSessionTokenHint: 'Optional, for temporary credentials', + bedrockSecretKeyLeaveEmpty: 'Leave empty to keep current key', + bedrockApiKeyLabel: 'Bedrock API Key', + bedrockApiKeyDesc: 'Bearer Token', + bedrockApiKeyInput: 'API Key', + bedrockApiKeyRequired: 'Please enter Bedrock API Key', + bedrockApiKeyLeaveEmpty: 'Leave empty to keep current key', apiKeyIsRequired: 'API Key is required', leaveEmptyToKeep: 'Leave empty to keep current key', // Upstream type diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 96ec5508..d8a00166 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2069,6 +2069,8 @@ export default { accountType: '账号类型', claudeCode: 'Claude Code', claudeConsole: 'Claude Console', + bedrockLabel: 'AWS Bedrock', + bedrockDesc: 'SigV4 签名', oauthSetupToken: 'OAuth / Setup Token', addMethod: '添加方式', setupTokenLongLived: 'Setup Token(长期有效)', @@ -2251,6 +2253,23 @@ export default { mixedChannelWarning: '警告:分组 "{groupName}" 中同时包含 {currentPlatform} 和 {otherPlatform} 账号。混合使用不同渠道可能导致 thinking block 签名验证问题,会自动回退到非 thinking 模式。确定要继续吗?', pleaseEnterAccountName: '请输入账号名称', pleaseEnterApiKey: '请输入 API Key', + bedrockAccessKeyId: 'AWS Access Key ID', + bedrockSecretAccessKey: 'AWS Secret Access Key', + bedrockSessionToken: 'AWS Session Token', + bedrockRegion: 'AWS Region', + bedrockRegionHint: '例如 us-east-1, us-west-2, eu-west-1', + bedrockForceGlobal: '强制使用 Global 跨区域推理', + bedrockForceGlobalHint: '启用后模型 ID 使用 global. 
前缀(如 global.anthropic.claude-...),请求可路由到全球任意支持的区域,获得更高可用性', + bedrockAccessKeyIdRequired: '请输入 AWS Access Key ID', + bedrockSecretAccessKeyRequired: '请输入 AWS Secret Access Key', + bedrockRegionRequired: '请选择 AWS Region', + bedrockSessionTokenHint: '可选,用于临时凭证', + bedrockSecretKeyLeaveEmpty: '留空以保持当前密钥', + bedrockApiKeyLabel: 'Bedrock API Key', + bedrockApiKeyDesc: 'Bearer Token 认证', + bedrockApiKeyInput: 'API Key', + bedrockApiKeyRequired: '请输入 Bedrock API Key', + bedrockApiKeyLeaveEmpty: '留空以保持当前密钥', apiKeyIsRequired: 'API Key 是必需的', leaveEmptyToKeep: '留空以保持当前密钥', // Upstream type diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 5764134d..cbfa8e28 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -531,7 +531,7 @@ export interface UpdateGroupRequest { // ==================== Account & Proxy Types ==================== export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' | 'sora' -export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' +export type AccountType = 'oauth' | 'setup-token' | 'apikey' | 'upstream' | 'bedrock' | 'bedrock-apikey' export type OAuthAddMethod = 'oauth' | 'setup-token' export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' From e90ec847b6bb060729344b7803acf3441429518d Mon Sep 17 00:00:00 2001 From: Ylarod Date: Fri, 13 Mar 2026 19:15:27 +0800 Subject: [PATCH 2/2] fix lint --- backend/internal/domain/constants.go | 8 +-- backend/internal/service/bedrock_stream.go | 4 +- .../internal/service/bedrock_stream_test.go | 54 +++++++++---------- backend/internal/service/domain_constants.go | 8 +-- backend/internal/service/gateway_service.go | 19 ------- 5 files changed, 37 insertions(+), 56 deletions(-) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 9920a76b..36d043b5 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -27,10 +27,10 @@ const ( // 
Account type constants const ( - AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = "apikey" // API Key类型账号 - AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeBedrock = "bedrock" // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) AccountTypeBedrockAPIKey = "bedrock-apikey" // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) ) diff --git a/backend/internal/service/bedrock_stream.go b/backend/internal/service/bedrock_stream.go index 30c1011b..98196d27 100644 --- a/backend/internal/service/bedrock_stream.go +++ b/backend/internal/service/bedrock_stream.go @@ -282,8 +282,8 @@ func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) { // 验证 message CRC(覆盖 prelude + headers + payload) messageCRC := bedrockReadUint32(data[len(data)-4:]) h := crc32.New(crc32IEEETable) - h.Write(prelude) - h.Write(data[:len(data)-4]) + _, _ = h.Write(prelude) + _, _ = h.Write(data[:len(data)-4]) if h.Sum32() != messageCRC { return nil, fmt.Errorf("eventstream message CRC mismatch") } diff --git a/backend/internal/service/bedrock_stream_test.go b/backend/internal/service/bedrock_stream_test.go index 500b9292..3d066137 100644 --- a/backend/internal/service/bedrock_stream_test.go +++ b/backend/internal/service/bedrock_stream_test.go @@ -76,15 +76,15 @@ func TestExtractEventStreamHeaderValue(t *testing.T) { buildStringHeader := func(name, value string) []byte { var buf bytes.Buffer // name length (1 byte) - buf.WriteByte(byte(len(name))) + _ = buf.WriteByte(byte(len(name))) // name - buf.WriteString(name) + _, _ = buf.WriteString(name) // value type (7 = string) - 
buf.WriteByte(7) + _ = buf.WriteByte(7) // value length (2 bytes, big-endian) _ = binary.Write(&buf, binary.BigEndian, uint16(len(value))) // value - buf.WriteString(value) + _, _ = buf.WriteString(value) return buf.Bytes() } @@ -100,9 +100,9 @@ func TestExtractEventStreamHeaderValue(t *testing.T) { t.Run("multiple headers", func(t *testing.T) { var buf bytes.Buffer - buf.Write(buildStringHeader(":content-type", "application/json")) - buf.Write(buildStringHeader(":event-type", "chunk")) - buf.Write(buildStringHeader(":message-type", "event")) + _, _ = buf.Write(buildStringHeader(":content-type", "application/json")) + _, _ = buf.Write(buildStringHeader(":event-type", "chunk")) + _, _ = buf.Write(buildStringHeader(":message-type", "event")) headers := buf.Bytes() assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type")) @@ -123,17 +123,17 @@ func TestBedrockEventStreamDecoder(t *testing.T) { // Build headers var headersBuf bytes.Buffer // :event-type header - headersBuf.WriteByte(byte(len(":event-type"))) - headersBuf.WriteString(":event-type") - headersBuf.WriteByte(7) // string type + _ = headersBuf.WriteByte(byte(len(":event-type"))) + _, _ = headersBuf.WriteString(":event-type") + _ = headersBuf.WriteByte(7) // string type _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType))) - headersBuf.WriteString(eventType) + _, _ = headersBuf.WriteString(eventType) // :message-type header - headersBuf.WriteByte(byte(len(":message-type"))) - headersBuf.WriteString(":message-type") - headersBuf.WriteByte(7) + _ = headersBuf.WriteByte(byte(len(":message-type"))) + _, _ = headersBuf.WriteString(":message-type") + _ = headersBuf.WriteByte(7) _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event"))) - headersBuf.WriteString("event") + _, _ = headersBuf.WriteString("event") headers := headersBuf.Bytes() headersLen := uint32(len(headers)) @@ -149,10 +149,10 @@ func TestBedrockEventStreamDecoder(t *testing.T) { // Build frame: 
prelude + prelude_crc + headers + payload var frame bytes.Buffer - frame.Write(preludeBytes) + _, _ = frame.Write(preludeBytes) _ = binary.Write(&frame, binary.BigEndian, preludeCRC) - frame.Write(headers) - frame.Write(payload) + _, _ = frame.Write(headers) + _, _ = frame.Write(payload) // Message CRC covers everything before itself messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab) @@ -173,9 +173,9 @@ func TestBedrockEventStreamDecoder(t *testing.T) { t.Run("skip non-chunk events", func(t *testing.T) { // Write initial-response followed by chunk var buf bytes.Buffer - buf.Write(buildFrame("initial-response", []byte(`{}`))) + _, _ = buf.Write(buildFrame("initial-response", []byte(`{}`))) chunkPayload := []byte(`{"bytes":"aGVsbG8="}`) - buf.Write(buildFrame("chunk", chunkPayload)) + _, _ = buf.Write(buildFrame("chunk", chunkPayload)) decoder := newBedrockEventStreamDecoder(&buf) result, err := decoder.Decode() @@ -214,11 +214,11 @@ func TestBedrockEventStreamDecoder(t *testing.T) { payload := []byte(`{"bytes":"dGVzdA=="}`) var headersBuf bytes.Buffer - headersBuf.WriteByte(byte(len(":event-type"))) - headersBuf.WriteString(":event-type") - headersBuf.WriteByte(7) + _ = headersBuf.WriteByte(byte(len(":event-type"))) + _, _ = headersBuf.WriteString(":event-type") + _ = headersBuf.WriteByte(7) _ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk"))) - headersBuf.WriteString("chunk") + _, _ = headersBuf.WriteString("chunk") headers := headersBuf.Bytes() headersLen := uint32(len(headers)) @@ -230,10 +230,10 @@ func TestBedrockEventStreamDecoder(t *testing.T) { preludeBytes := preludeBuf.Bytes() var frame bytes.Buffer - frame.Write(preludeBytes) + _, _ = frame.Write(preludeBytes) _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab)) - frame.Write(headers) - frame.Write(payload) + _, _ = frame.Write(headers) + _, _ = frame.Write(payload) _ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), 
castagnoliTab)) decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes())) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 26905fb9..ad64b467 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -29,10 +29,10 @@ const ( // Account type constants const ( - AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) - AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 - AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) + AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) + AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) AccountTypeBedrock = domain.AccountTypeBedrock // AWS Bedrock 类型账号(通过 SigV4 签名连接 Bedrock) AccountTypeBedrockAPIKey = domain.AccountTypeBedrockAPIKey // AWS Bedrock API Key 类型账号(通过 Bearer Token 连接 Bedrock) ) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 978ec98f..ce1c746c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5853,25 +5853,6 @@ func containsBetaToken(header, token string) bool { return false } -// filterBetaTokensFromHeader removes tokens present in filterSet from a comma-separated header value. -// Returns the filtered header string, or "" if all tokens were removed. 
-func filterBetaTokensFromHeader(header string, filterSet map[string]struct{}) string { - if header == "" || len(filterSet) == 0 { - return header - } - var kept []string - for _, p := range strings.Split(header, ",") { - t := strings.TrimSpace(p) - if t == "" { - continue - } - if _, filtered := filterSet[t]; !filtered { - kept = append(kept, t) - } - } - return strings.Join(kept, ", ") -} - func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string { if len(tokens) == 0 || len(filterSet) == 0 { return tokens