This commit is contained in:
SunSeekerX
2026-01-21 11:55:28 +08:00
149 changed files with 15035 additions and 4017 deletions

View File

@@ -0,0 +1,789 @@
const redis = require('../models/redis')
const balanceScriptService = require('./balanceScriptService')
const logger = require('../utils/logger')
const CostCalculator = require('../utils/costCalculator')
const { isBalanceScriptEnabled } = require('../utils/featureFlags')
class AccountBalanceService {
constructor(options = {}) {
this.redis = options.redis || redis
this.logger = options.logger || logger
this.providers = new Map()
this.CACHE_TTL_SECONDS = 3600
this.LOCAL_TTL_SECONDS = 300
this.LOW_BALANCE_THRESHOLD = 10
this.HIGH_USAGE_THRESHOLD_PERCENT = 90
this.DEFAULT_CONCURRENCY = 10
}
getSupportedPlatforms() {
return [
'claude',
'claude-console',
'gemini',
'gemini-api',
'openai',
'openai-responses',
'azure_openai',
'bedrock',
'droid',
'ccr'
]
}
normalizePlatform(platform) {
if (!platform) {
return null
}
const value = String(platform).trim().toLowerCase()
// 兼容实施文档与历史命名
if (value === 'claude-official') {
return 'claude'
}
if (value === 'azure-openai') {
return 'azure_openai'
}
// 保持前端平台键一致
return value
}
registerProvider(platform, provider) {
const normalized = this.normalizePlatform(platform)
if (!normalized) {
throw new Error('registerProvider: 缺少 platform')
}
if (!provider || typeof provider.queryBalance !== 'function') {
throw new Error(`registerProvider: Provider 无效 (${normalized})`)
}
this.providers.set(normalized, provider)
}
async getAccountBalance(accountId, platform, options = {}) {
const normalizedPlatform = this.normalizePlatform(platform)
const account = await this.getAccount(accountId, normalizedPlatform)
if (!account) {
return null
}
return await this._getAccountBalanceForAccount(account, normalizedPlatform, options)
}
async refreshAccountBalance(accountId, platform) {
const normalizedPlatform = this.normalizePlatform(platform)
const account = await this.getAccount(accountId, normalizedPlatform)
if (!account) {
return null
}
return await this._getAccountBalanceForAccount(account, normalizedPlatform, {
queryApi: true,
useCache: false
})
}
async getAllAccountsBalance(platform, options = {}) {
const normalizedPlatform = this.normalizePlatform(platform)
const accounts = await this.getAllAccountsByPlatform(normalizedPlatform)
const queryApi = this._parseBoolean(options.queryApi) || false
const useCache = options.useCache !== false
const results = await this._mapWithConcurrency(
accounts,
this.DEFAULT_CONCURRENCY,
async (acc) => {
try {
const balance = await this._getAccountBalanceForAccount(acc, normalizedPlatform, {
queryApi,
useCache
})
return { ...balance, name: acc.name || '' }
} catch (error) {
this.logger.error(`批量获取余额失败: ${normalizedPlatform}:${acc?.id}`, error)
return {
success: true,
data: {
accountId: acc?.id,
platform: normalizedPlatform,
balance: null,
quota: null,
statistics: {},
source: 'local',
lastRefreshAt: new Date().toISOString(),
cacheExpiresAt: null,
status: 'error',
error: error.message || '批量查询失败'
},
name: acc?.name || ''
}
}
}
)
return results
}
async getBalanceSummary() {
const platforms = this.getSupportedPlatforms()
const summary = {
totalBalance: 0,
totalCost: 0,
lowBalanceCount: 0,
platforms: {}
}
for (const platform of platforms) {
const accounts = await this.getAllAccountsByPlatform(platform)
const platformData = {
count: accounts.length,
totalBalance: 0,
totalCost: 0,
lowBalanceCount: 0,
accounts: []
}
const balances = await this._mapWithConcurrency(
accounts,
this.DEFAULT_CONCURRENCY,
async (acc) => {
const balance = await this._getAccountBalanceForAccount(acc, platform, {
queryApi: false,
useCache: true
})
return { ...balance, name: acc.name || '' }
}
)
for (const item of balances) {
platformData.accounts.push(item)
const amount = item?.data?.balance?.amount
const percentage = item?.data?.quota?.percentage
const totalCost = Number(item?.data?.statistics?.totalCost || 0)
const hasAmount = typeof amount === 'number' && Number.isFinite(amount)
const isLowBalance = hasAmount && amount < this.LOW_BALANCE_THRESHOLD
const isHighUsage =
typeof percentage === 'number' &&
Number.isFinite(percentage) &&
percentage > this.HIGH_USAGE_THRESHOLD_PERCENT
if (hasAmount) {
platformData.totalBalance += amount
}
if (isLowBalance || isHighUsage) {
platformData.lowBalanceCount += 1
summary.lowBalanceCount += 1
}
platformData.totalCost += totalCost
}
summary.platforms[platform] = platformData
summary.totalBalance += platformData.totalBalance
summary.totalCost += platformData.totalCost
}
return summary
}
async clearCache(accountId, platform) {
const normalizedPlatform = this.normalizePlatform(platform)
if (!normalizedPlatform) {
throw new Error('缺少 platform 参数')
}
await this.redis.deleteAccountBalance(normalizedPlatform, accountId)
this.logger.info(`余额缓存已清除: ${normalizedPlatform}:${accountId}`)
}
async getAccount(accountId, platform) {
if (!accountId || !platform) {
return null
}
const serviceMap = {
claude: require('./claudeAccountService'),
'claude-console': require('./claudeConsoleAccountService'),
gemini: require('./geminiAccountService'),
'gemini-api': require('./geminiApiAccountService'),
openai: require('./openaiAccountService'),
'openai-responses': require('./openaiResponsesAccountService'),
azure_openai: require('./azureOpenaiAccountService'),
bedrock: require('./bedrockAccountService'),
droid: require('./droidAccountService'),
ccr: require('./ccrAccountService')
}
const service = serviceMap[platform]
if (!service || typeof service.getAccount !== 'function') {
return null
}
const result = await service.getAccount(accountId)
// 处理不同服务返回格式的差异
// Bedrock/CCR/Droid 等服务返回 { success, data } 格式
if (result && typeof result === 'object' && 'success' in result && 'data' in result) {
return result.success ? result.data : null
}
return result
}
async getAllAccountsByPlatform(platform) {
if (!platform) {
return []
}
const serviceMap = {
claude: require('./claudeAccountService'),
'claude-console': require('./claudeConsoleAccountService'),
gemini: require('./geminiAccountService'),
'gemini-api': require('./geminiApiAccountService'),
openai: require('./openaiAccountService'),
'openai-responses': require('./openaiResponsesAccountService'),
azure_openai: require('./azureOpenaiAccountService'),
bedrock: require('./bedrockAccountService'),
droid: require('./droidAccountService'),
ccr: require('./ccrAccountService')
}
const service = serviceMap[platform]
if (!service) {
return []
}
// Bedrock 特殊:返回 { success, data }
if (platform === 'bedrock' && typeof service.getAllAccounts === 'function') {
const result = await service.getAllAccounts()
return result?.success ? result.data || [] : []
}
if (platform === 'openai-responses') {
return await service.getAllAccounts(true)
}
if (typeof service.getAllAccounts !== 'function') {
return []
}
return await service.getAllAccounts()
}
async _getAccountBalanceForAccount(account, platform, options = {}) {
const queryMode = this._parseQueryMode(options.queryApi)
const useCache = options.useCache !== false
const accountId = account?.id
if (!accountId) {
// 如果账户缺少 id返回空响应而不是抛出错误避免接口报错和UI错误
this.logger.warn('账户缺少 id返回空余额数据', { account, platform })
return this._buildResponse(
{
status: 'error',
errorMessage: '账户数据异常',
balance: null,
currency: 'USD',
quota: null,
statistics: {},
lastRefreshAt: new Date().toISOString()
},
'unknown',
platform,
'local',
null,
{ scriptEnabled: false, scriptConfigured: false }
)
}
// 余额脚本配置状态(用于前端控制"刷新余额"按钮)
let scriptConfig = null
let scriptConfigured = false
if (typeof this.redis?.getBalanceScriptConfig === 'function') {
scriptConfig = await this.redis.getBalanceScriptConfig(platform, accountId)
scriptConfigured = !!(
scriptConfig &&
scriptConfig.scriptBody &&
String(scriptConfig.scriptBody).trim().length > 0
)
}
const scriptEnabled = isBalanceScriptEnabled()
const scriptMeta = { scriptEnabled, scriptConfigured }
const localBalance = await this._getBalanceFromLocal(accountId, platform)
const localStatistics = localBalance.statistics || {}
const quotaFromLocal = this._buildQuotaFromLocal(account, localStatistics)
// 安全限制queryApi=auto 仅用于 Antigravitygemini + oauthProvider=antigravity账户
const effectiveQueryMode =
queryMode === 'auto' && !(platform === 'gemini' && account?.oauthProvider === 'antigravity')
? 'local'
: queryMode
// local: 仅本地统计/缓存auto: 优先缓存,无缓存则尝试远程 Provider并缓存结果
if (effectiveQueryMode !== 'api') {
if (useCache) {
const cached = await this.redis.getAccountBalance(platform, accountId)
if (cached && cached.status === 'success') {
return this._buildResponse(
{
status: cached.status,
errorMessage: cached.errorMessage,
balance: quotaFromLocal.balance ?? cached.balance,
currency: quotaFromLocal.currency || cached.currency || 'USD',
quota: quotaFromLocal.quota || cached.quota || null,
statistics: localStatistics,
lastRefreshAt: cached.lastRefreshAt
},
accountId,
platform,
'cache',
cached.ttlSeconds,
scriptMeta
)
}
}
if (effectiveQueryMode === 'local') {
return this._buildResponse(
{
status: 'success',
errorMessage: null,
balance: quotaFromLocal.balance,
currency: quotaFromLocal.currency || 'USD',
quota: quotaFromLocal.quota,
statistics: localStatistics,
lastRefreshAt: localBalance.lastCalculated
},
accountId,
platform,
'local',
null,
scriptMeta
)
}
}
// 强制查询:优先脚本(如启用且已配置),否则调用 Provider失败自动降级到本地统计
let providerResult
if (scriptEnabled && scriptConfigured) {
providerResult = await this._getBalanceFromScript(scriptConfig, accountId, platform)
} else {
const provider = this.providers.get(platform)
if (!provider) {
return this._buildResponse(
{
status: 'error',
errorMessage: `不支持的平台: ${platform}`,
balance: quotaFromLocal.balance,
currency: quotaFromLocal.currency || 'USD',
quota: quotaFromLocal.quota,
statistics: localStatistics,
lastRefreshAt: new Date().toISOString()
},
accountId,
platform,
'local',
null,
scriptMeta
)
}
providerResult = await this._getBalanceFromProvider(provider, account)
}
const isRemoteSuccess =
providerResult.status === 'success' && ['api', 'script'].includes(providerResult.queryMethod)
// 仅缓存“真实远程查询成功”的结果,避免把字段/本地降级结果当作 API 结果缓存 1h
if (isRemoteSuccess) {
await this.redis.setAccountBalance(
platform,
accountId,
providerResult,
this.CACHE_TTL_SECONDS
)
}
const source = isRemoteSuccess ? 'api' : 'local'
return this._buildResponse(
{
status: providerResult.status,
errorMessage: providerResult.errorMessage,
balance: quotaFromLocal.balance ?? providerResult.balance,
currency: quotaFromLocal.currency || providerResult.currency || 'USD',
quota: quotaFromLocal.quota || providerResult.quota || null,
statistics: localStatistics,
lastRefreshAt: providerResult.lastRefreshAt
},
accountId,
platform,
source,
null,
scriptMeta
)
}
async _getBalanceFromScript(scriptConfig, accountId, platform) {
try {
const result = await balanceScriptService.execute({
scriptBody: scriptConfig.scriptBody,
timeoutSeconds: scriptConfig.timeoutSeconds || 10,
variables: {
baseUrl: scriptConfig.baseUrl || '',
apiKey: scriptConfig.apiKey || '',
token: scriptConfig.token || '',
accountId,
platform,
extra: scriptConfig.extra || ''
}
})
const mapped = result?.mapped || {}
return {
status: mapped.status || 'error',
balance: typeof mapped.balance === 'number' ? mapped.balance : null,
currency: mapped.currency || 'USD',
quota: mapped.quota || null,
queryMethod: 'api',
rawData: mapped.rawData || result?.response?.data || null,
lastRefreshAt: new Date().toISOString(),
errorMessage: mapped.errorMessage || ''
}
} catch (error) {
return {
status: 'error',
balance: null,
currency: 'USD',
quota: null,
queryMethod: 'api',
rawData: null,
lastRefreshAt: new Date().toISOString(),
errorMessage: error.message || '脚本执行失败'
}
}
}
async _getBalanceFromProvider(provider, account) {
try {
const result = await provider.queryBalance(account)
return {
status: 'success',
balance: typeof result?.balance === 'number' ? result.balance : null,
currency: result?.currency || 'USD',
quota: result?.quota || null,
queryMethod: result?.queryMethod || 'api',
rawData: result?.rawData || null,
lastRefreshAt: new Date().toISOString(),
errorMessage: ''
}
} catch (error) {
return {
status: 'error',
balance: null,
currency: 'USD',
quota: null,
queryMethod: 'api',
rawData: null,
lastRefreshAt: new Date().toISOString(),
errorMessage: error.message || '查询失败'
}
}
}
async _getBalanceFromLocal(accountId, platform) {
const cached = await this.redis.getLocalBalance(platform, accountId)
if (cached && cached.statistics) {
return cached
}
const statistics = await this._computeLocalStatistics(accountId)
const localBalance = {
status: 'success',
balance: null,
currency: 'USD',
statistics,
queryMethod: 'local',
lastCalculated: new Date().toISOString()
}
await this.redis.setLocalBalance(platform, accountId, localBalance, this.LOCAL_TTL_SECONDS)
return localBalance
}
async _computeLocalStatistics(accountId) {
const safeNumber = (value) => {
const num = Number(value)
return Number.isFinite(num) ? num : 0
}
try {
const usageStats = await this.redis.getAccountUsageStats(accountId)
const dailyCost = safeNumber(usageStats?.daily?.cost || 0)
const monthlyCost = await this._computeMonthlyCost(accountId)
const totalCost = await this._computeTotalCost(accountId)
return {
totalCost,
dailyCost,
monthlyCost,
totalRequests: safeNumber(usageStats?.total?.requests || 0),
dailyRequests: safeNumber(usageStats?.daily?.requests || 0),
monthlyRequests: safeNumber(usageStats?.monthly?.requests || 0)
}
} catch (error) {
this.logger.debug(`本地统计计算失败: ${accountId}`, error)
return {
totalCost: 0,
dailyCost: 0,
monthlyCost: 0,
totalRequests: 0,
dailyRequests: 0,
monthlyRequests: 0
}
}
}
async _computeMonthlyCost(accountId) {
const tzDate = this.redis.getDateInTimezone(new Date())
const currentMonth = `${tzDate.getUTCFullYear()}-${String(tzDate.getUTCMonth() + 1).padStart(
2,
'0'
)}`
const pattern = `account_usage:model:monthly:${accountId}:*:${currentMonth}`
return await this._sumModelCostsByKeysPattern(pattern)
}
async _computeTotalCost(accountId) {
const pattern = `account_usage:model:monthly:${accountId}:*:*`
return await this._sumModelCostsByKeysPattern(pattern)
}
async _sumModelCostsByKeysPattern(pattern) {
try {
const client = this.redis.getClientSafe()
let totalCost = 0
let cursor = '0'
const scanCount = 200
let iterations = 0
const maxIterations = 2000
do {
const [nextCursor, keys] = await client.scan(cursor, 'MATCH', pattern, 'COUNT', scanCount)
cursor = nextCursor
iterations += 1
if (!keys || keys.length === 0) {
continue
}
const pipeline = client.pipeline()
keys.forEach((key) => pipeline.hgetall(key))
const results = await pipeline.exec()
for (let i = 0; i < results.length; i += 1) {
const [, data] = results[i] || []
if (!data || Object.keys(data).length === 0) {
continue
}
const parts = String(keys[i]).split(':')
const model = parts[4] || 'unknown'
const usage = {
input_tokens: parseInt(data.inputTokens || 0),
output_tokens: parseInt(data.outputTokens || 0),
cache_creation_input_tokens: parseInt(data.cacheCreateTokens || 0),
cache_read_input_tokens: parseInt(data.cacheReadTokens || 0)
}
const costResult = CostCalculator.calculateCost(usage, model)
totalCost += costResult.costs.total || 0
}
if (iterations >= maxIterations) {
this.logger.warn(`SCAN 次数超过上限,停止汇总:${pattern}`)
break
}
} while (cursor !== '0')
return totalCost
} catch (error) {
this.logger.debug(`汇总模型费用失败: ${pattern}`, error)
return 0
}
}
_buildQuotaFromLocal(account, statistics) {
if (!account || !Object.prototype.hasOwnProperty.call(account, 'dailyQuota')) {
return { balance: null, currency: null, quota: null }
}
const dailyQuota = Number(account.dailyQuota || 0)
const used = Number(statistics?.dailyCost || 0)
const resetAt = this._computeNextResetAt(account.quotaResetTime || '00:00')
// 不限制
if (!Number.isFinite(dailyQuota) || dailyQuota <= 0) {
return {
balance: null,
currency: 'USD',
quota: {
daily: Infinity,
used,
remaining: Infinity,
percentage: 0,
unlimited: true,
resetAt
}
}
}
const remaining = Math.max(0, dailyQuota - used)
const percentage = dailyQuota > 0 ? (used / dailyQuota) * 100 : 0
return {
balance: remaining,
currency: 'USD',
quota: {
daily: dailyQuota,
used,
remaining,
resetAt,
percentage: Math.round(percentage * 100) / 100
}
}
}
_computeNextResetAt(resetTime) {
const now = new Date()
const tzNow = this.redis.getDateInTimezone(now)
const offsetMs = tzNow.getTime() - now.getTime()
const [h, m] = String(resetTime || '00:00')
.split(':')
.map((n) => parseInt(n, 10))
const resetHour = Number.isFinite(h) ? h : 0
const resetMinute = Number.isFinite(m) ? m : 0
const year = tzNow.getUTCFullYear()
const month = tzNow.getUTCMonth()
const day = tzNow.getUTCDate()
let resetAtMs = Date.UTC(year, month, day, resetHour, resetMinute, 0, 0) - offsetMs
if (resetAtMs <= now.getTime()) {
resetAtMs += 24 * 60 * 60 * 1000
}
return new Date(resetAtMs).toISOString()
}
_buildResponse(balanceData, accountId, platform, source, ttlSeconds = null, extraData = {}) {
const now = new Date()
const amount = typeof balanceData.balance === 'number' ? balanceData.balance : null
const currency = balanceData.currency || 'USD'
let cacheExpiresAt = null
if (source === 'cache') {
const ttl =
typeof ttlSeconds === 'number' && ttlSeconds > 0 ? ttlSeconds : this.CACHE_TTL_SECONDS
cacheExpiresAt = new Date(Date.now() + ttl * 1000).toISOString()
}
return {
success: true,
data: {
accountId,
platform,
balance:
typeof amount === 'number'
? {
amount,
currency,
formattedAmount: this._formatCurrency(amount, currency)
}
: null,
quota: balanceData.quota || null,
statistics: balanceData.statistics || {},
source,
lastRefreshAt: balanceData.lastRefreshAt || now.toISOString(),
cacheExpiresAt,
status: balanceData.status || 'success',
error: balanceData.errorMessage || null,
...(extraData && typeof extraData === 'object' ? extraData : {})
}
}
}
_formatCurrency(amount, currency = 'USD') {
try {
if (typeof amount !== 'number' || !Number.isFinite(amount)) {
return 'N/A'
}
return new Intl.NumberFormat('en-US', { style: 'currency', currency }).format(amount)
} catch (error) {
return `$${amount.toFixed(2)}`
}
}
_parseBoolean(value) {
if (typeof value === 'boolean') {
return value
}
if (typeof value !== 'string') {
return null
}
const normalized = value.trim().toLowerCase()
if (normalized === 'true' || normalized === '1' || normalized === 'yes') {
return true
}
if (normalized === 'false' || normalized === '0' || normalized === 'no') {
return false
}
return null
}
_parseQueryMode(value) {
if (value === 'auto') {
return 'auto'
}
const parsed = this._parseBoolean(value)
return parsed ? 'api' : 'local'
}
async _mapWithConcurrency(items, limit, mapper) {
const concurrency = Math.max(1, Number(limit) || 1)
const list = Array.isArray(items) ? items : []
const results = new Array(list.length)
let nextIndex = 0
const workers = new Array(Math.min(concurrency, list.length)).fill(null).map(async () => {
while (nextIndex < list.length) {
const currentIndex = nextIndex
nextIndex += 1
results[currentIndex] = await mapper(list[currentIndex], currentIndex)
}
})
await Promise.all(workers)
return results
}
}
const accountBalanceService = new AccountBalanceService()
module.exports = accountBalanceService
module.exports.AccountBalanceService = AccountBalanceService

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,595 @@
const axios = require('axios')
const https = require('https')
const { v4: uuidv4 } = require('uuid')
const ProxyHelper = require('../utils/proxyHelper')
const logger = require('../utils/logger')
const {
mapAntigravityUpstreamModel,
normalizeAntigravityModelInput,
getAntigravityModelMetadata
} = require('../utils/antigravityModel')
const { cleanJsonSchemaForGemini } = require('../utils/geminiSchemaCleaner')
const { dumpAntigravityUpstreamRequest } = require('../utils/antigravityUpstreamDump')
const keepAliveAgent = new https.Agent({
keepAlive: true,
keepAliveMsecs: 30000,
timeout: 120000,
maxSockets: 100,
maxFreeSockets: 10
})
function getAntigravityApiUrl() {
return process.env.ANTIGRAVITY_API_URL || 'https://daily-cloudcode-pa.sandbox.googleapis.com'
}
function normalizeBaseUrl(url) {
const str = String(url || '').trim()
return str.endsWith('/') ? str.slice(0, -1) : str
}
function getAntigravityApiUrlCandidates() {
const configured = normalizeBaseUrl(getAntigravityApiUrl())
const daily = 'https://daily-cloudcode-pa.sandbox.googleapis.com'
const prod = 'https://cloudcode-pa.googleapis.com'
// 若显式配置了自定义 base url则只使用该地址不做 fallback避免意外路由到别的环境
if (process.env.ANTIGRAVITY_API_URL) {
return [configured]
}
// 默认行为:优先 daily与旧逻辑一致失败时再尝试 prod对齐 CLIProxyAPI
if (configured === normalizeBaseUrl(daily)) {
return [configured, prod]
}
if (configured === normalizeBaseUrl(prod)) {
return [configured, daily]
}
return [configured, prod, daily].filter(Boolean)
}
function getAntigravityHeaders(accessToken, baseUrl) {
const resolvedBaseUrl = baseUrl || getAntigravityApiUrl()
let host = 'daily-cloudcode-pa.sandbox.googleapis.com'
try {
host = new URL(resolvedBaseUrl).host || host
} catch (e) {
// ignore
}
return {
Host: host,
'User-Agent': process.env.ANTIGRAVITY_USER_AGENT || 'antigravity/1.11.3 windows/amd64',
Authorization: `Bearer ${accessToken}`,
'Content-Type': 'application/json',
'Accept-Encoding': 'gzip',
requestType: 'agent'
}
}
function generateAntigravityProjectId() {
return `ag-${uuidv4().replace(/-/g, '').slice(0, 16)}`
}
function generateAntigravitySessionId() {
return `sess-${uuidv4()}`
}
function resolveAntigravityProjectId(projectId, requestData) {
const candidate = projectId || requestData?.project || requestData?.projectId || null
return candidate || generateAntigravityProjectId()
}
function resolveAntigravitySessionId(sessionId, requestData) {
const candidate =
sessionId || requestData?.request?.sessionId || requestData?.request?.session_id || null
return candidate || generateAntigravitySessionId()
}
function buildAntigravityEnvelope({ requestData, projectId, sessionId, userPromptId }) {
const model = mapAntigravityUpstreamModel(requestData?.model)
const resolvedProjectId = resolveAntigravityProjectId(projectId, requestData)
const resolvedSessionId = resolveAntigravitySessionId(sessionId, requestData)
const requestPayload = {
...(requestData?.request || {})
}
if (requestPayload.session_id !== undefined) {
delete requestPayload.session_id
}
requestPayload.sessionId = resolvedSessionId
const envelope = {
project: resolvedProjectId,
requestId: `req-${uuidv4()}`,
model,
userAgent: 'antigravity',
request: {
...requestPayload
}
}
if (userPromptId) {
envelope.user_prompt_id = userPromptId
envelope.userPromptId = userPromptId
}
normalizeAntigravityEnvelope(envelope)
return { model, envelope }
}
function normalizeAntigravityThinking(model, requestPayload) {
if (!requestPayload || typeof requestPayload !== 'object') {
return
}
const { generationConfig } = requestPayload
if (!generationConfig || typeof generationConfig !== 'object') {
return
}
const { thinkingConfig } = generationConfig
if (!thinkingConfig || typeof thinkingConfig !== 'object') {
return
}
const normalizedModel = normalizeAntigravityModelInput(model)
if (thinkingConfig.thinkingLevel && !normalizedModel.startsWith('gemini-3-')) {
delete thinkingConfig.thinkingLevel
}
const metadata = getAntigravityModelMetadata(normalizedModel)
if (metadata && !metadata.thinking) {
delete generationConfig.thinkingConfig
return
}
if (!metadata || !metadata.thinking) {
return
}
const budgetRaw = Number(thinkingConfig.thinkingBudget)
if (!Number.isFinite(budgetRaw)) {
return
}
let budget = Math.trunc(budgetRaw)
const minBudget = Number.isFinite(metadata.thinking.min) ? metadata.thinking.min : null
const maxBudget = Number.isFinite(metadata.thinking.max) ? metadata.thinking.max : null
if (maxBudget !== null && budget > maxBudget) {
budget = maxBudget
}
let effectiveMax = Number.isFinite(generationConfig.maxOutputTokens)
? generationConfig.maxOutputTokens
: null
let setDefaultMax = false
if (!effectiveMax && metadata.maxCompletionTokens) {
effectiveMax = metadata.maxCompletionTokens
setDefaultMax = true
}
if (effectiveMax && budget >= effectiveMax) {
budget = Math.max(0, effectiveMax - 1)
}
if (minBudget !== null && budget >= 0 && budget < minBudget) {
delete generationConfig.thinkingConfig
return
}
thinkingConfig.thinkingBudget = budget
if (setDefaultMax) {
generationConfig.maxOutputTokens = effectiveMax
}
}
function normalizeAntigravityEnvelope(envelope) {
if (!envelope || typeof envelope !== 'object') {
return
}
const model = String(envelope.model || '')
const requestPayload = envelope.request
if (!requestPayload || typeof requestPayload !== 'object') {
return
}
if (requestPayload.safetySettings !== undefined) {
delete requestPayload.safetySettings
}
// 对齐 CLIProxyAPI有 tools 时默认启用 VALIDATED除非显式 NONE
if (Array.isArray(requestPayload.tools) && requestPayload.tools.length > 0) {
const existing = requestPayload?.toolConfig?.functionCallingConfig || null
if (existing?.mode !== 'NONE') {
const nextCfg = { ...(existing || {}), mode: 'VALIDATED' }
requestPayload.toolConfig = { functionCallingConfig: nextCfg }
}
}
// 对齐 CLIProxyAPI非 Claude 模型移除 maxOutputTokensAntigravity 环境不稳定)
normalizeAntigravityThinking(model, requestPayload)
if (!model.includes('claude')) {
if (requestPayload.generationConfig && typeof requestPayload.generationConfig === 'object') {
delete requestPayload.generationConfig.maxOutputTokens
}
return
}
// Claude 模型parametersJsonSchema -> parameters + schema 清洗(避免 $schema / additionalProperties 等触发 400
if (!Array.isArray(requestPayload.tools)) {
return
}
for (const tool of requestPayload.tools) {
if (!tool || typeof tool !== 'object') {
continue
}
const decls = Array.isArray(tool.functionDeclarations)
? tool.functionDeclarations
: Array.isArray(tool.function_declarations)
? tool.function_declarations
: null
if (!decls) {
continue
}
for (const decl of decls) {
if (!decl || typeof decl !== 'object') {
continue
}
let schema =
decl.parametersJsonSchema !== undefined ? decl.parametersJsonSchema : decl.parameters
if (typeof schema === 'string' && schema) {
try {
schema = JSON.parse(schema)
} catch (_) {
schema = null
}
}
decl.parameters = cleanJsonSchemaForGemini(schema)
delete decl.parametersJsonSchema
}
}
}
async function request({
accessToken,
proxyConfig = null,
requestData,
projectId = null,
sessionId = null,
userPromptId = null,
stream = false,
signal = null,
params = null,
timeoutMs = null
}) {
const { model, envelope } = buildAntigravityEnvelope({
requestData,
projectId,
sessionId,
userPromptId
})
const proxyAgent = ProxyHelper.createProxyAgent(proxyConfig)
let endpoints = getAntigravityApiUrlCandidates()
// Claude 模型在 sandbox(daily) 环境下对 tool_use/tool_result 的兼容性不稳定,优先走 prod。
// 保持可配置优先:若用户显式设置了 ANTIGRAVITY_API_URL则不改变顺序。
if (!process.env.ANTIGRAVITY_API_URL && String(model).includes('claude')) {
const prodHost = 'cloudcode-pa.googleapis.com'
const dailyHost = 'daily-cloudcode-pa.sandbox.googleapis.com'
const ordered = []
for (const u of endpoints) {
if (String(u).includes(prodHost)) {
ordered.push(u)
}
}
for (const u of endpoints) {
if (!String(u).includes(prodHost)) {
ordered.push(u)
}
}
// 去重并保持 prod -> daily 的稳定顺序
endpoints = Array.from(new Set(ordered)).sort((a, b) => {
const av = String(a)
const bv = String(b)
const aScore = av.includes(prodHost) ? 0 : av.includes(dailyHost) ? 1 : 2
const bScore = bv.includes(prodHost) ? 0 : bv.includes(dailyHost) ? 1 : 2
return aScore - bScore
})
}
const isRetryable = (error) => {
// 处理网络层面的连接重置或超时(常见于长请求被中间节点切断)
if (error.code === 'ECONNRESET' || error.code === 'ETIMEDOUT') {
return true
}
const status = error?.response?.status
if (status === 429) {
return true
}
// 400/404 的 “model unavailable / not found” 在不同环境间可能表现不同,允许 fallback。
if (status === 400 || status === 404) {
const data = error?.response?.data
const safeToString = (value) => {
if (typeof value === 'string') {
return value
}
if (value === null || value === undefined) {
return ''
}
// axios responseType=stream 时data 可能是 stream存在循环引用不能 JSON.stringify
if (typeof value === 'object' && typeof value.pipe === 'function') {
return ''
}
if (Buffer.isBuffer(value)) {
try {
return value.toString('utf8')
} catch (_) {
return ''
}
}
if (typeof value === 'object') {
try {
return JSON.stringify(value)
} catch (_) {
return ''
}
}
return String(value)
}
const text = safeToString(data)
const msg = (text || '').toLowerCase()
return (
msg.includes('requested model is currently unavailable') ||
msg.includes('tool_use') ||
msg.includes('tool_result') ||
msg.includes('requested entity was not found') ||
msg.includes('not found')
)
}
return false
}
let lastError = null
let retriedAfterDelay = false
const attemptRequest = async () => {
for (let index = 0; index < endpoints.length; index += 1) {
const baseUrl = endpoints[index]
const url = `${baseUrl}/v1internal:${stream ? 'streamGenerateContent' : 'generateContent'}`
const axiosConfig = {
url,
method: 'POST',
...(params ? { params } : {}),
headers: getAntigravityHeaders(accessToken, baseUrl),
data: envelope,
timeout: stream ? 0 : timeoutMs || 600000,
...(stream ? { responseType: 'stream' } : {})
}
if (proxyAgent) {
axiosConfig.httpsAgent = proxyAgent
axiosConfig.proxy = false
if (index === 0) {
logger.info(
`🌐 Using proxy for Antigravity ${stream ? 'streamGenerateContent' : 'generateContent'}: ${ProxyHelper.getProxyDescription(proxyConfig)}`
)
}
} else {
axiosConfig.httpsAgent = keepAliveAgent
}
if (signal) {
axiosConfig.signal = signal
}
try {
dumpAntigravityUpstreamRequest({
requestId: envelope.requestId,
model,
stream,
url,
baseUrl,
params: axiosConfig.params || null,
headers: axiosConfig.headers,
envelope
}).catch(() => {})
const response = await axios(axiosConfig)
return { model, response }
} catch (error) {
lastError = error
const status = error?.response?.status || null
const hasNext = index + 1 < endpoints.length
if (hasNext && isRetryable(error)) {
logger.warn('⚠️ Antigravity upstream error, retrying with fallback baseUrl', {
status,
from: baseUrl,
to: endpoints[index + 1],
model
})
continue
}
throw error
}
}
throw lastError || new Error('Antigravity request failed')
}
try {
return await attemptRequest()
} catch (error) {
// 如果是 429 RESOURCE_EXHAUSTED 且尚未重试过,等待 2 秒后重试一次
const status = error?.response?.status
if (status === 429 && !retriedAfterDelay && !signal?.aborted) {
const data = error?.response?.data
// 安全地将 data 转为字符串,避免 stream 对象导致循环引用崩溃
const safeDataToString = (value) => {
if (typeof value === 'string') {
return value
}
if (value === null || value === undefined) {
return ''
}
// stream 对象存在循环引用,不能 JSON.stringify
if (typeof value === 'object' && typeof value.pipe === 'function') {
return ''
}
if (Buffer.isBuffer(value)) {
try {
return value.toString('utf8')
} catch (_) {
return ''
}
}
if (typeof value === 'object') {
try {
return JSON.stringify(value)
} catch (_) {
return ''
}
}
return String(value)
}
const msg = safeDataToString(data)
if (
msg.toLowerCase().includes('resource_exhausted') ||
msg.toLowerCase().includes('no capacity')
) {
retriedAfterDelay = true
logger.warn('⏳ Antigravity 429 RESOURCE_EXHAUSTED, waiting 2s before retry', { model })
await new Promise((resolve) => setTimeout(resolve, 2000))
return await attemptRequest()
}
}
throw error
}
}
async function fetchAvailableModels({ accessToken, proxyConfig = null, timeoutMs = 30000 }) {
const proxyAgent = ProxyHelper.createProxyAgent(proxyConfig)
const endpoints = getAntigravityApiUrlCandidates()
let lastError = null
for (let index = 0; index < endpoints.length; index += 1) {
const baseUrl = endpoints[index]
const url = `${baseUrl}/v1internal:fetchAvailableModels`
const axiosConfig = {
url,
method: 'POST',
headers: getAntigravityHeaders(accessToken, baseUrl),
data: {},
timeout: timeoutMs
}
if (proxyAgent) {
axiosConfig.httpsAgent = proxyAgent
axiosConfig.proxy = false
if (index === 0) {
logger.info(
`🌐 Using proxy for Antigravity fetchAvailableModels: ${ProxyHelper.getProxyDescription(proxyConfig)}`
)
}
} else {
axiosConfig.httpsAgent = keepAliveAgent
}
try {
const response = await axios(axiosConfig)
return response.data
} catch (error) {
lastError = error
const status = error?.response?.status
const hasNext = index + 1 < endpoints.length
if (hasNext && (status === 429 || status === 404)) {
continue
}
throw error
}
}
throw lastError || new Error('Antigravity fetchAvailableModels failed')
}
async function countTokens({
accessToken,
proxyConfig = null,
contents,
model,
timeoutMs = 30000
}) {
const upstreamModel = mapAntigravityUpstreamModel(model)
const proxyAgent = ProxyHelper.createProxyAgent(proxyConfig)
const endpoints = getAntigravityApiUrlCandidates()
let lastError = null
for (let index = 0; index < endpoints.length; index += 1) {
const baseUrl = endpoints[index]
const url = `${baseUrl}/v1internal:countTokens`
const axiosConfig = {
url,
method: 'POST',
headers: getAntigravityHeaders(accessToken, baseUrl),
data: {
request: {
model: `models/${upstreamModel}`,
contents
}
},
timeout: timeoutMs
}
if (proxyAgent) {
axiosConfig.httpsAgent = proxyAgent
axiosConfig.proxy = false
if (index === 0) {
logger.info(
`🌐 Using proxy for Antigravity countTokens: ${ProxyHelper.getProxyDescription(proxyConfig)}`
)
}
} else {
axiosConfig.httpsAgent = keepAliveAgent
}
try {
const response = await axios(axiosConfig)
return response.data
} catch (error) {
lastError = error
const status = error?.response?.status
const hasNext = index + 1 < endpoints.length
if (hasNext && (status === 429 || status === 404)) {
continue
}
throw error
}
}
throw lastError || new Error('Antigravity countTokens failed')
}
module.exports = {
getAntigravityApiUrl,
getAntigravityApiUrlCandidates,
getAntigravityHeaders,
buildAntigravityEnvelope,
request,
fetchAvailableModels,
countTokens
}

View File

@@ -0,0 +1,173 @@
const apiKeyService = require('./apiKeyService')
const { convertMessagesToGemini, convertGeminiResponse } = require('./geminiRelayService')
const { normalizeAntigravityModelInput } = require('../utils/antigravityModel')
const antigravityClient = require('./antigravityClient')
function buildRequestData({ messages, model, temperature, maxTokens, sessionId }) {
const requestedModel = normalizeAntigravityModelInput(model)
const { contents, systemInstruction } = convertMessagesToGemini(messages)
const requestData = {
model: requestedModel,
request: {
contents,
generationConfig: {
temperature,
maxOutputTokens: maxTokens,
candidateCount: 1,
topP: 0.95,
topK: 40
},
...(sessionId ? { sessionId } : {})
}
}
if (systemInstruction) {
requestData.request.systemInstruction = { parts: [{ text: systemInstruction }] }
}
return requestData
}
async function* handleStreamResponse(response, model, apiKeyId, accountId) {
let buffer = ''
let totalUsage = {
promptTokenCount: 0,
candidatesTokenCount: 0,
totalTokenCount: 0
}
let usageRecorded = false
try {
for await (const chunk of response.data) {
buffer += chunk.toString()
const lines = buffer.split('\n')
buffer = lines.pop() || ''
for (const line of lines) {
if (!line.trim()) {
continue
}
let jsonData = line
if (line.startsWith('data: ')) {
jsonData = line.substring(6).trim()
}
if (!jsonData || jsonData === '[DONE]') {
continue
}
try {
const data = JSON.parse(jsonData)
const payload = data?.response || data
if (payload?.usageMetadata) {
totalUsage = payload.usageMetadata
}
const openaiChunk = convertGeminiResponse(payload, model, true)
if (openaiChunk) {
yield `data: ${JSON.stringify(openaiChunk)}\n\n`
const finishReason = openaiChunk.choices?.[0]?.finish_reason
if (finishReason === 'stop') {
yield 'data: [DONE]\n\n'
if (apiKeyId && totalUsage.totalTokenCount > 0) {
await apiKeyService.recordUsage(
apiKeyId,
totalUsage.promptTokenCount || 0,
totalUsage.candidatesTokenCount || 0,
0,
0,
model,
accountId,
'gemini'
)
usageRecorded = true
}
return
}
}
} catch (e) {
// ignore chunk parse errors
}
}
}
} finally {
if (!usageRecorded && apiKeyId && totalUsage.totalTokenCount > 0) {
await apiKeyService.recordUsage(
apiKeyId,
totalUsage.promptTokenCount || 0,
totalUsage.candidatesTokenCount || 0,
0,
0,
model,
accountId,
'gemini'
)
}
}
}
async function sendAntigravityRequest({
messages,
model,
temperature = 0.7,
maxTokens = 4096,
stream = false,
accessToken,
proxy,
apiKeyId,
signal,
projectId,
accountId = null
}) {
const requestedModel = normalizeAntigravityModelInput(model)
const requestData = buildRequestData({
messages,
model: requestedModel,
temperature,
maxTokens,
sessionId: apiKeyId
})
const { response } = await antigravityClient.request({
accessToken,
proxyConfig: proxy,
requestData,
projectId,
sessionId: apiKeyId,
stream,
signal,
params: { alt: 'sse' }
})
if (stream) {
return handleStreamResponse(response, requestedModel, apiKeyId, accountId)
}
const payload = response.data?.response || response.data
const openaiResponse = convertGeminiResponse(payload, requestedModel, false)
if (apiKeyId && openaiResponse?.usage) {
await apiKeyService.recordUsage(
apiKeyId,
openaiResponse.usage.prompt_tokens || 0,
openaiResponse.usage.completion_tokens || 0,
0,
0,
requestedModel,
accountId,
'gemini'
)
}
return openaiResponse
}
module.exports = {
sendAntigravityRequest
}

View File

@@ -38,6 +38,51 @@ const ACCOUNT_CATEGORY_MAP = {
droid: 'droid'
}
/**
* 规范化权限数据,兼容旧格式(字符串)和新格式(数组)
* @param {string|array} permissions - 权限数据
* @returns {array} - 权限数组,空数组表示全部服务
*/
function normalizePermissions(permissions) {
if (!permissions) {
return [] // 空 = 全部服务
}
if (Array.isArray(permissions)) {
return permissions
}
// 尝试解析 JSON 字符串(新格式存储)
if (typeof permissions === 'string') {
if (permissions.startsWith('[')) {
try {
const parsed = JSON.parse(permissions)
if (Array.isArray(parsed)) {
return parsed
}
} catch (e) {
// 解析失败,继续处理为普通字符串
}
}
// 旧格式 'all' 转为空数组
if (permissions === 'all') {
return []
}
// 旧单个字符串转为数组
return [permissions]
}
return []
}
/**
* 检查是否有访问特定服务的权限
* @param {string|array} permissions - 权限数据
* @param {string} service - 服务名称claude/gemini/openai/droid
* @returns {boolean} - 是否有权限
*/
function hasPermission(permissions, service) {
const perms = normalizePermissions(permissions)
return perms.length === 0 || perms.includes(service) // 空数组 = 全部服务
}
function normalizeAccountTypeKey(type) {
if (!type) {
return null
@@ -90,7 +135,7 @@ class ApiKeyService {
azureOpenaiAccountId = null,
bedrockAccountId = null, // 添加 Bedrock 账号ID支持
droidAccountId = null,
permissions = 'all', // 可选值:'claude''gemini'、'openai'、'droid' 或 'all'聚合Key为数组
permissions = [], // 数组格式,空数组表示全部服务,如 ['claude', 'gemini']
isActive = true,
concurrencyLimit = 0,
rateLimitWindow = null,
@@ -108,12 +153,7 @@ class ApiKeyService {
activationUnit = 'days', // 新增:激活时间单位 'hours' 或 'days'
expirationMode = 'fixed', // 新增:过期模式 'fixed'(固定时间) 或 'activation'(首次使用后激活)
icon = '', // 新增图标base64编码
// 聚合 Key 相关字段
isAggregated = false, // 是否为聚合 Key
quotaLimit = 0, // CC 额度上限(聚合 Key 使用)
quotaUsed = 0, // 已消耗 CC 额度
serviceQuotaLimits = {}, // 分服务额度限制 { claude: 50, codex: 30 }
serviceQuotaUsed = {} // 分服务已消耗额度
serviceRates = {} // API Key 级别服务倍率覆盖
} = options
// 生成简单的API Key (64字符十六进制)
@@ -121,15 +161,8 @@ class ApiKeyService {
const keyId = uuidv4()
const hashedKey = this._hashApiKey(apiKey)
// 处理 permissions:聚合 Key 使用数组,传统 Key 使用字符串
// 处理 permissions
let permissionsValue = permissions
if (isAggregated && !Array.isArray(permissions)) {
// 聚合 Key 但 permissions 不是数组,转换为数组
permissionsValue =
permissions === 'all'
? ['claude', 'codex', 'gemini', 'droid', 'bedrock', 'azure', 'ccr']
: [permissions]
}
const keyData = {
id: keyId,
@@ -149,9 +182,7 @@ class ApiKeyService {
azureOpenaiAccountId: azureOpenaiAccountId || '',
bedrockAccountId: bedrockAccountId || '', // 添加 Bedrock 账号ID
droidAccountId: droidAccountId || '',
permissions: Array.isArray(permissionsValue)
? JSON.stringify(permissionsValue)
: permissionsValue || 'all',
permissions: JSON.stringify(normalizePermissions(permissions)),
enableModelRestriction: String(enableModelRestriction),
restrictedModels: JSON.stringify(restrictedModels || []),
enableClientRestriction: String(enableClientRestriction || false),
@@ -172,12 +203,7 @@ class ApiKeyService {
userId: options.userId || '',
userUsername: options.userUsername || '',
icon: icon || '', // 新增图标base64编码
// 聚合 Key 相关字段
isAggregated: String(isAggregated),
quotaLimit: String(quotaLimit || 0),
quotaUsed: String(quotaUsed || 0),
serviceQuotaLimits: JSON.stringify(serviceQuotaLimits || {}),
serviceQuotaUsed: JSON.stringify(serviceQuotaUsed || {})
serviceRates: JSON.stringify(serviceRates || {}) // API Key 级别服务倍率
}
// 保存API Key数据并建立哈希映射
@@ -207,9 +233,7 @@ class ApiKeyService {
logger.warn(`Failed to add key ${keyId} to API Key index:`, err.message)
}
logger.success(
`🔑 Generated new API key: ${name} (${keyId})${isAggregated ? ' [Aggregated]' : ''}`
)
logger.success(`🔑 Generated new API key: ${name} (${keyId})`)
// 解析 permissions 用于返回
let parsedPermissions = keyData.permissions
@@ -237,7 +261,7 @@ class ApiKeyService {
azureOpenaiAccountId: keyData.azureOpenaiAccountId,
bedrockAccountId: keyData.bedrockAccountId, // 添加 Bedrock 账号ID
droidAccountId: keyData.droidAccountId,
permissions: parsedPermissions,
permissions: normalizePermissions(keyData.permissions),
enableModelRestriction: keyData.enableModelRestriction === 'true',
restrictedModels: JSON.parse(keyData.restrictedModels),
enableClientRestriction: keyData.enableClientRestriction === 'true',
@@ -254,12 +278,7 @@ class ApiKeyService {
createdAt: keyData.createdAt,
expiresAt: keyData.expiresAt,
createdBy: keyData.createdBy,
// 聚合 Key 相关字段
isAggregated: keyData.isAggregated === 'true',
quotaLimit: parseFloat(keyData.quotaLimit || 0),
quotaUsed: parseFloat(keyData.quotaUsed || 0),
serviceQuotaLimits: JSON.parse(keyData.serviceQuotaLimits || '{}'),
serviceQuotaUsed: JSON.parse(keyData.serviceQuotaUsed || '{}')
serviceRates: JSON.parse(keyData.serviceRates || '{}') // API Key 级别服务倍率
}
}
@@ -399,14 +418,10 @@ class ApiKeyService {
// 不是 JSON保持原值
}
// 解析聚合 Key 相关字段
let serviceQuotaLimits = {}
let serviceQuotaUsed = {}
// 解析 serviceRates
let serviceRates = {}
try {
serviceQuotaLimits = keyData.serviceQuotaLimits
? JSON.parse(keyData.serviceQuotaLimits)
: {}
serviceQuotaUsed = keyData.serviceQuotaUsed ? JSON.parse(keyData.serviceQuotaUsed) : {}
serviceRates = keyData.serviceRates ? JSON.parse(keyData.serviceRates) : {}
} catch (e) {
// 解析失败使用默认值
}
@@ -426,7 +441,7 @@ class ApiKeyService {
azureOpenaiAccountId: keyData.azureOpenaiAccountId,
bedrockAccountId: keyData.bedrockAccountId, // 添加 Bedrock 账号ID
droidAccountId: keyData.droidAccountId,
permissions,
permissions: normalizePermissions(keyData.permissions),
tokenLimit: parseInt(keyData.tokenLimit),
concurrencyLimit: parseInt(keyData.concurrencyLimit || 0),
rateLimitWindow: parseInt(keyData.rateLimitWindow || 0),
@@ -443,12 +458,7 @@ class ApiKeyService {
totalCost: costData.totalCost || 0,
weeklyOpusCost: costData.weeklyOpusCost || 0,
tags,
// 聚合 Key 相关字段
isAggregated: keyData.isAggregated === 'true',
quotaLimit: parseFloat(keyData.quotaLimit || 0),
quotaUsed: parseFloat(keyData.quotaUsed || 0),
serviceQuotaLimits,
serviceQuotaUsed
serviceRates
}
}
} catch (error) {
@@ -560,7 +570,7 @@ class ApiKeyService {
azureOpenaiAccountId: keyData.azureOpenaiAccountId,
bedrockAccountId: keyData.bedrockAccountId,
droidAccountId: keyData.droidAccountId,
permissions: keyData.permissions || 'all',
permissions: normalizePermissions(keyData.permissions),
tokenLimit: parseInt(keyData.tokenLimit),
concurrencyLimit: parseInt(keyData.concurrencyLimit || 0),
rateLimitWindow: parseInt(keyData.rateLimitWindow || 0),
@@ -586,9 +596,154 @@ class ApiKeyService {
}
}
// 🏷️ 获取所有标签(轻量级,使用 SCAN + Pipeline
// 🏷️ 获取所有标签(合并索引和全局集合
async getAllTags() {
return await redis.scanAllApiKeyTags()
const indexTags = await redis.scanAllApiKeyTags()
const globalTags = await redis.getGlobalTags()
// 过滤空值和空格
return [
...new Set(
[...indexTags, ...globalTags].map((t) => (t ? t.trim() : '')).filter((t) => t)
)
].sort()
}
// 🏷️ 创建新标签
async createTag(tagName) {
const existingTags = await this.getAllTags()
if (existingTags.includes(tagName)) {
return { success: false, error: '标签已存在' }
}
await redis.addTag(tagName)
return { success: true }
}
// 🏷️ 获取标签详情(含使用数量)
async getTagsWithCount() {
const apiKeys = await redis.getAllApiKeys()
const tagCounts = new Map()
// 统计 API Key 上的标签trim 后统计)
for (const key of apiKeys) {
if (key.isDeleted === 'true') continue
let tags = []
try {
const parsed = key.tags ? JSON.parse(key.tags) : []
tags = Array.isArray(parsed) ? parsed : []
} catch {
tags = []
}
for (const tag of tags) {
if (typeof tag === 'string') {
const trimmed = tag.trim()
if (trimmed) {
tagCounts.set(trimmed, (tagCounts.get(trimmed) || 0) + 1)
}
}
}
}
// 直接获取全局标签集合(避免重复扫描)
const globalTags = await redis.getGlobalTags()
for (const tag of globalTags) {
const trimmed = tag ? tag.trim() : ''
if (trimmed && !tagCounts.has(trimmed)) {
tagCounts.set(trimmed, 0)
}
}
return Array.from(tagCounts.entries())
.map(([name, count]) => ({ name, count }))
.sort((a, b) => b.count - a.count)
}
// 🏷️ 从所有 API Key 中移除指定标签
async removeTagFromAllKeys(tagName) {
const normalizedName = (tagName || '').trim()
if (!normalizedName) {
return { affectedCount: 0 }
}
const apiKeys = await redis.getAllApiKeys()
let affectedCount = 0
for (const key of apiKeys) {
if (key.isDeleted === 'true') continue
let tags = []
try {
const parsed = key.tags ? JSON.parse(key.tags) : []
tags = Array.isArray(parsed) ? parsed : []
} catch {
tags = []
}
// 匹配时 trim 比较,过滤非字符串
const strTags = tags.filter((t) => typeof t === 'string')
if (strTags.some((t) => t.trim() === normalizedName)) {
const newTags = strTags.filter((t) => t.trim() !== normalizedName)
await this.updateApiKey(key.id, { tags: newTags })
affectedCount++
}
}
// 同时从全局标签集合删除
await redis.removeTag(normalizedName)
await redis.removeTag(tagName) // 也删除原始值(可能带空格)
return { affectedCount }
}
// 🏷️ 重命名标签
async renameTag(oldName, newName) {
if (!newName || !newName.trim()) {
return { affectedCount: 0, error: '新标签名不能为空' }
}
const normalizedOld = (oldName || '').trim()
const normalizedNew = newName.trim()
if (!normalizedOld) {
return { affectedCount: 0, error: '旧标签名不能为空' }
}
const apiKeys = await redis.getAllApiKeys()
let affectedCount = 0
let foundInKeys = false
for (const key of apiKeys) {
if (key.isDeleted === 'true') continue
let tags = []
try {
const parsed = key.tags ? JSON.parse(key.tags) : []
tags = Array.isArray(parsed) ? parsed : []
} catch {
tags = []
}
// 匹配时 trim 比较,过滤非字符串
const strTags = tags.filter((t) => typeof t === 'string')
if (strTags.some((t) => t.trim() === normalizedOld)) {
foundInKeys = true
const newTags = [...new Set(strTags.map((t) => (t.trim() === normalizedOld ? normalizedNew : t)))]
await this.updateApiKey(key.id, { tags: newTags })
affectedCount++
}
}
// 检查全局集合是否有该标签
const globalTags = await redis.getGlobalTags()
const foundInGlobal = globalTags.some((t) => typeof t === 'string' && t.trim() === normalizedOld)
if (!foundInKeys && !foundInGlobal) {
return { affectedCount: 0, error: '标签不存在' }
}
// 同时更新全局标签集合(删旧加新)
await redis.removeTag(normalizedOld)
await redis.removeTag(oldName) // 也删除原始值
await redis.addTag(normalizedNew)
return { affectedCount }
}
// 📋 获取所有API Keys
@@ -623,7 +778,7 @@ class ApiKeyService {
key.isActive = key.isActive === 'true'
key.enableModelRestriction = key.enableModelRestriction === 'true'
key.enableClientRestriction = key.enableClientRestriction === 'true'
key.permissions = key.permissions || 'all' // 兼容旧数据
key.permissions = normalizePermissions(key.permissions)
key.dailyCostLimit = parseFloat(key.dailyCostLimit || 0)
key.totalCostLimit = parseFloat(key.totalCostLimit || 0)
key.weeklyOpusCostLimit = parseFloat(key.weeklyOpusCostLimit || 0)
@@ -1060,12 +1215,7 @@ class ApiKeyService {
'userId', // 新增用户ID所有者变更
'userUsername', // 新增:用户名(所有者变更)
'createdBy', // 新增:创建者(所有者变更)
// 聚合 Key 相关字段
'isAggregated',
'quotaLimit',
'quotaUsed',
'serviceQuotaLimits',
'serviceQuotaUsed'
'serviceRates' // API Key 级别服务倍率
]
const updatedData = { ...keyData }
@@ -1075,21 +1225,17 @@ class ApiKeyService {
field === 'restrictedModels' ||
field === 'allowedClients' ||
field === 'tags' ||
field === 'serviceQuotaLimits' ||
field === 'serviceQuotaUsed'
field === 'serviceRates'
) {
// 特殊处理数组/对象字段
updatedData[field] = JSON.stringify(
value || (field === 'serviceQuotaLimits' || field === 'serviceQuotaUsed' ? {} : [])
)
updatedData[field] = JSON.stringify(value || (field === 'serviceRates' ? {} : []))
} else if (field === 'permissions') {
// permissions 可能是数组(聚合 Key或字符串传统 Key
// permissions 可能是数组或字符串
updatedData[field] = Array.isArray(value) ? JSON.stringify(value) : value || 'all'
} else if (
field === 'enableModelRestriction' ||
field === 'enableClientRestriction' ||
field === 'isActivated' ||
field === 'isAggregated'
field === 'isActivated'
) {
// 布尔值转字符串
updatedData[field] = String(value)
@@ -1358,7 +1504,7 @@ class ApiKeyService {
}
}
// 📊 记录使用情况支持缓存token和账户级别统计
// 📊 记录使用情况支持缓存token和账户级别统计,应用服务倍率
async recordUsage(
keyId,
inputTokens = 0,
@@ -1366,7 +1512,8 @@ class ApiKeyService {
cacheCreateTokens = 0,
cacheReadTokens = 0,
model = 'unknown',
accountId = null
accountId = null,
accountType = null
) {
try {
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
@@ -1404,12 +1551,21 @@ class ApiKeyService {
isLongContextRequest
)
// 记录费用统计
if (costInfo.costs.total > 0) {
await redis.incrementDailyCost(keyId, costInfo.costs.total)
// 记录费用统计(应用服务倍率)
const realCost = costInfo.costs.total
let ratedCost = realCost
if (realCost > 0) {
const serviceRatesService = require('./serviceRatesService')
const service = serviceRatesService.getService(accountType, model)
ratedCost = await this.calculateRatedCost(keyId, service, realCost)
await redis.incrementDailyCost(keyId, ratedCost, realCost)
logger.database(
`💰 Recorded cost for ${keyId}: $${costInfo.costs.total.toFixed(6)}, model: ${model}`
`💰 Recorded cost for ${keyId}: rated=$${ratedCost.toFixed(6)}, real=$${realCost.toFixed(6)}, model: ${model}, service: ${service}`
)
// 记录 Opus 周费用(如果适用)
await this.recordOpusCost(keyId, ratedCost, realCost, model, accountType)
} else {
logger.debug(`💰 No cost recorded for ${keyId} - zero cost for model: ${model}`)
}
@@ -1452,19 +1608,20 @@ class ApiKeyService {
}
}
// 记录单次请求的使用详情
const usageCost = costInfo && costInfo.costs ? costInfo.costs.total || 0 : 0
// 记录单次请求的使用详情(同时保存真实成本和倍率成本)
await redis.addUsageRecord(keyId, {
timestamp: new Date().toISOString(),
model,
accountId: accountId || null,
accountType: accountType || null,
inputTokens,
outputTokens,
cacheCreateTokens,
cacheReadTokens,
totalTokens,
cost: Number(usageCost.toFixed(6)),
costBreakdown: costInfo && costInfo.costs ? costInfo.costs : undefined
cost: Number(ratedCost.toFixed(6)),
realCost: Number(realCost.toFixed(6)),
realCostBreakdown: costInfo && costInfo.costs ? costInfo.costs : undefined
})
const logParts = [`Model: ${model}`, `Input: ${inputTokens}`, `Output: ${outputTokens}`]
@@ -1483,28 +1640,26 @@ class ApiKeyService {
}
// 📊 记录 Opus 模型费用(仅限 claude 和 claude-console 账户)
async recordOpusCost(keyId, cost, model, accountType) {
// ratedCost: 倍率后的成本(用于限额校验)
// realCost: 真实成本(用于对账),如果不传则等于 ratedCost
async recordOpusCost(keyId, ratedCost, realCost, model, accountType) {
try {
// 判断是否为 Opus 模型
if (!model || !model.toLowerCase().includes('claude-opus')) {
return // 不是 Opus 模型,直接返回
}
// 判断是否为 claude、claude-console 或 ccr 账户
if (
!accountType ||
(accountType !== 'claude' && accountType !== 'claude-console' && accountType !== 'ccr')
) {
// 判断是否为 claude-official、claude-console 或 ccr 账户
const opusAccountTypes = ['claude-official', 'claude-console', 'ccr']
if (!accountType || !opusAccountTypes.includes(accountType)) {
logger.debug(`⚠️ Skipping Opus cost recording for non-Claude account type: ${accountType}`)
return // 不是 claude 账户,直接返回
}
// 记录 Opus 周费用
await redis.incrementWeeklyOpusCost(keyId, cost)
// 记录 Opus 周费用(倍率成本和真实成本)
await redis.incrementWeeklyOpusCost(keyId, ratedCost, realCost)
logger.database(
`💰 Recorded Opus weekly cost for ${keyId}: $${cost.toFixed(
6
)}, model: ${model}, account type: ${accountType}`
`💰 Recorded Opus weekly cost for ${keyId}: rated=$${ratedCost.toFixed(6)}, real=$${realCost.toFixed(6)}, model: ${model}`
)
} catch (error) {
logger.error('❌ Failed to record Opus cost:', error)
@@ -1603,15 +1758,22 @@ class ApiKeyService {
costInfo.isLongContextRequest || false // 传递 1M 上下文请求标记
)
// 记录费用统计
if (costInfo.totalCost > 0) {
await redis.incrementDailyCost(keyId, costInfo.totalCost)
// 记录费用统计(应用服务倍率)
const realCostWithDetails = costInfo.totalCost || 0
let ratedCostWithDetails = realCostWithDetails
if (realCostWithDetails > 0) {
const serviceRatesService = require('./serviceRatesService')
const service = serviceRatesService.getService(accountType, model)
ratedCostWithDetails = await this.calculateRatedCost(keyId, service, realCostWithDetails)
// 记录倍率成本和真实成本
await redis.incrementDailyCost(keyId, ratedCostWithDetails, realCostWithDetails)
logger.database(
`💰 Recorded cost for ${keyId}: $${costInfo.totalCost.toFixed(6)}, model: ${model}`
`💰 Recorded cost for ${keyId}: rated=$${ratedCostWithDetails.toFixed(6)}, real=$${realCostWithDetails.toFixed(6)}, model: ${model}, service: ${service}`
)
// 记录 Opus 周费用(如果适用)
await this.recordOpusCost(keyId, costInfo.totalCost, model, accountType)
// 记录 Opus 周费用(如果适用,也应用倍率
await this.recordOpusCost(keyId, ratedCostWithDetails, realCostWithDetails, model, accountType)
// 记录详细的缓存费用(如果有)
if (costInfo.ephemeral5mCost > 0 || costInfo.ephemeral1hCost > 0) {
@@ -1683,8 +1845,9 @@ class ApiKeyService {
ephemeral5mTokens,
ephemeral1hTokens,
totalTokens,
cost: Number((costInfo.totalCost || 0).toFixed(6)),
costBreakdown: {
cost: Number(ratedCostWithDetails.toFixed(6)),
realCost: Number(realCostWithDetails.toFixed(6)),
realCostBreakdown: {
input: costInfo.inputCost || 0,
output: costInfo.outputCost || 0,
cacheCreate: costInfo.cacheCreateCost || 0,
@@ -1921,9 +2084,19 @@ class ApiKeyService {
const recordLimit = optionObject.recordLimit || 20
const recentRecords = await redis.getUsageRecords(keyId, recordLimit)
// API 兼容:同时输出 costBreakdown 和 realCostBreakdown
const compatibleRecords = recentRecords.map((record) => {
const breakdown = record.realCostBreakdown || record.costBreakdown
return {
...record,
costBreakdown: breakdown,
realCostBreakdown: breakdown
}
})
return {
...usageStats,
recentRecords
recentRecords: compatibleRecords
}
}
@@ -2022,7 +2195,7 @@ class ApiKeyService {
userId: keyData.userId,
userUsername: keyData.userUsername,
createdBy: keyData.createdBy,
permissions: keyData.permissions,
permissions: normalizePermissions(keyData.permissions),
dailyCostLimit: parseFloat(keyData.dailyCostLimit || 0),
totalCostLimit: parseFloat(keyData.totalCostLimit || 0),
// 所有平台账户绑定字段
@@ -2268,320 +2441,97 @@ class ApiKeyService {
}
// ═══════════════════════════════════════════════════════════════════════════
// 聚合 Key 相关方法
// 服务倍率和费用限制相关方法
// ═══════════════════════════════════════════════════════════════════════════
/**
* 判断是否为聚合 Key
*/
isAggregatedKey(keyData) {
return keyData && keyData.isAggregated === 'true'
}
/**
* 获取转换为聚合 Key 的预览信息
* 计算应用倍率后的费用
* 公式:消费计费 = 真实消费 × 全局倍率 × Key 倍率
* @param {string} keyId - API Key ID
* @returns {Object} 转换预览信息
*/
async getConvertToAggregatedPreview(keyId) {
try {
const keyData = await redis.getApiKey(keyId)
if (!keyData || Object.keys(keyData).length === 0) {
throw new Error('API key not found')
}
if (keyData.isAggregated === 'true') {
throw new Error('API key is already aggregated')
}
// 获取当前服务倍率
const ratesConfig = await serviceRatesService.getRates()
// 确定原服务类型
let originalService = keyData.permissions || 'all'
try {
const parsed = JSON.parse(originalService)
if (Array.isArray(parsed)) {
originalService = parsed[0] || 'claude'
}
} catch (e) {
// 不是 JSON保持原值
}
// 映射 permissions 到服务类型
const serviceMap = {
all: 'claude',
claude: 'claude',
gemini: 'gemini',
openai: 'codex',
droid: 'droid',
bedrock: 'bedrock',
azure: 'azure',
ccr: 'ccr'
}
const mappedService = serviceMap[originalService] || 'claude'
const serviceRate = ratesConfig.rates[mappedService] || 1.0
// 获取已消耗的真实成本
const costStats = await redis.getCostStats(keyId)
const totalRealCost = costStats?.total || 0
// 获取原限额
const originalLimit = parseFloat(keyData.totalCostLimit || 0)
// 计算建议的 CC 额度
const suggestedQuotaLimit = originalLimit * serviceRate
const suggestedQuotaUsed = totalRealCost * serviceRate
// 获取可用服务列表
const availableServices = Object.keys(ratesConfig.rates)
return {
keyId,
keyName: keyData.name,
originalService: mappedService,
originalPermissions: originalService,
originalLimit,
originalUsed: totalRealCost,
serviceRate,
suggestedQuotaLimit: Math.round(suggestedQuotaLimit * 1000) / 1000,
suggestedQuotaUsed: Math.round(suggestedQuotaUsed * 1000) / 1000,
availableServices,
ratesConfig: ratesConfig.rates
}
} catch (error) {
logger.error('❌ Failed to get convert preview:', error)
throw error
}
}
/**
* 将传统 Key 转换为聚合 Key
* @param {string} keyId - API Key ID
* @param {Object} options - 转换选项
* @param {number} options.quotaLimit - CC 额度上限
* @param {Array<string>} options.permissions - 允许的服务列表
* @param {Object} options.serviceQuotaLimits - 分服务额度限制(可选)
* @returns {Object} 转换结果
*/
async convertToAggregated(keyId, options = {}) {
try {
const keyData = await redis.getApiKey(keyId)
if (!keyData || Object.keys(keyData).length === 0) {
throw new Error('API key not found')
}
if (keyData.isAggregated === 'true') {
throw new Error('API key is already aggregated')
}
const {
quotaLimit,
permissions,
serviceQuotaLimits = {},
quotaUsed = null // 如果不传,自动计算
} = options
if (quotaLimit === undefined || quotaLimit === null) {
throw new Error('quotaLimit is required')
}
if (!permissions || !Array.isArray(permissions) || permissions.length === 0) {
throw new Error('permissions must be a non-empty array')
}
// 计算已消耗的 CC 额度
let calculatedQuotaUsed = quotaUsed
if (calculatedQuotaUsed === null) {
// 自动计算:获取原服务的倍率,按倍率换算已消耗成本
const preview = await this.getConvertToAggregatedPreview(keyId)
calculatedQuotaUsed = preview.suggestedQuotaUsed
}
// 更新 Key 数据
const updates = {
isAggregated: true,
quotaLimit,
quotaUsed: calculatedQuotaUsed,
permissions,
serviceQuotaLimits,
serviceQuotaUsed: {} // 转换时重置分服务统计
}
await this.updateApiKey(keyId, updates)
logger.success(`🔄 Converted API key ${keyId} to aggregated key with ${quotaLimit} CC quota`)
return {
success: true,
keyId,
quotaLimit,
quotaUsed: calculatedQuotaUsed,
permissions
}
} catch (error) {
logger.error('❌ Failed to convert to aggregated key:', error)
throw error
}
}
/**
* 快速检查聚合 Key 额度(用于请求前检查)
* @param {string} keyId - API Key ID
* @returns {Object} { allowed: boolean, reason?: string, quotaRemaining?: number }
*/
async quickQuotaCheck(keyId) {
try {
const [quotaUsed, quotaLimit, isAggregated, isActive] = await redis.client.hmget(
`api_key:${keyId}`,
'quotaUsed',
'quotaLimit',
'isAggregated',
'isActive'
)
// 非聚合 Key 不检查额度
if (isAggregated !== 'true') {
return { allowed: true, isAggregated: false }
}
if (isActive !== 'true') {
return { allowed: false, reason: 'inactive', isAggregated: true }
}
const used = parseFloat(quotaUsed || 0)
const limit = parseFloat(quotaLimit || 0)
// 额度为 0 表示不限制
if (limit === 0) {
return { allowed: true, isAggregated: true, quotaRemaining: Infinity }
}
if (used >= limit) {
return {
allowed: false,
reason: 'quota_exceeded',
isAggregated: true,
quotaUsed: used,
quotaLimit: limit
}
}
return { allowed: true, isAggregated: true, quotaRemaining: limit - used }
} catch (error) {
logger.error('❌ Quick quota check failed:', error)
// 出错时允许通过,避免阻塞请求
return { allowed: true, error: error.message }
}
}
/**
* 异步扣减聚合 Key 额度(请求完成后调用)
* @param {string} keyId - API Key ID
* @param {number} costUSD - 真实成本USD
* @param {string} service - 服务类型
* @returns {Promise<void>}
* @param {number} realCost - 真实成本USD
* @returns {Promise<number>} 应用倍率后的费用
*/
async deductQuotaAsync(keyId, costUSD, service) {
// 使用 setImmediate 确保不阻塞响应
setImmediate(async () => {
async calculateRatedCost(keyId, service, realCost) {
try {
// 获取全局倍率
const globalRate = await serviceRatesService.getServiceRate(service)
// 获取 Key 倍率
const keyData = await redis.getApiKey(keyId)
let keyRates = {}
try {
// 获取服务倍率
const rate = await serviceRatesService.getServiceRate(service)
const quotaToDeduct = costUSD * rate
// 原子更新总额度
await redis.client.hincrbyfloat(`api_key:${keyId}`, 'quotaUsed', quotaToDeduct)
// 更新分服务额度
const serviceQuotaUsedStr = await redis.client.hget(`api_key:${keyId}`, 'serviceQuotaUsed')
let serviceQuotaUsed = {}
try {
serviceQuotaUsed = JSON.parse(serviceQuotaUsedStr || '{}')
} catch (e) {
serviceQuotaUsed = {}
}
serviceQuotaUsed[service] = (serviceQuotaUsed[service] || 0) + quotaToDeduct
await redis.client.hset(
`api_key:${keyId}`,
'serviceQuotaUsed',
JSON.stringify(serviceQuotaUsed)
)
logger.debug(
`📊 Deducted ${quotaToDeduct.toFixed(4)} CC quota from key ${keyId} (service: ${service}, rate: ${rate})`
)
} catch (error) {
logger.error(`❌ Failed to deduct quota for key ${keyId}:`, error)
keyRates = JSON.parse(keyData?.serviceRates || '{}')
} catch (e) {
keyRates = {}
}
})
const keyRate = keyRates[service] ?? 1.0
// 相乘计算
return realCost * globalRate * keyRate
} catch (error) {
logger.error('❌ Failed to calculate rated cost:', error)
// 出错时返回原始费用
return realCost
}
}
/**
* 增加聚合 Key 额度(用于核销额度卡)
* 增加 API Key 费用限制(用于核销额度卡)
* @param {string} keyId - API Key ID
* @param {number} quotaAmount - 要增加的 CC 额度
* @returns {Promise<Object>} { success: boolean, newQuotaLimit: number }
* @param {number} amount - 要增加的金额USD
* @returns {Promise<Object>} { success: boolean, newTotalCostLimit: number }
*/
async addQuota(keyId, quotaAmount) {
async addTotalCostLimit(keyId, amount) {
try {
const keyData = await redis.getApiKey(keyId)
if (!keyData || Object.keys(keyData).length === 0) {
throw new Error('API key not found')
}
if (keyData.isAggregated !== 'true') {
throw new Error('Only aggregated keys can add quota')
}
const currentLimit = parseFloat(keyData.totalCostLimit || 0)
const newLimit = currentLimit + amount
const currentLimit = parseFloat(keyData.quotaLimit || 0)
const newLimit = currentLimit + quotaAmount
await redis.client.hset(`apikey:${keyId}`, 'totalCostLimit', String(newLimit))
await redis.client.hset(`api_key:${keyId}`, 'quotaLimit', String(newLimit))
logger.success(`💰 Added $${amount} to key ${keyId}, new limit: $${newLimit}`)
logger.success(`💰 Added ${quotaAmount} CC quota to key ${keyId}, new limit: ${newLimit}`)
return { success: true, previousLimit: currentLimit, newQuotaLimit: newLimit }
return { success: true, previousLimit: currentLimit, newTotalCostLimit: newLimit }
} catch (error) {
logger.error('❌ Failed to add quota:', error)
logger.error('❌ Failed to add total cost limit:', error)
throw error
}
}
/**
* 减少聚合 Key 额度(用于撤销核销)
* 减少 API Key 费用限制(用于撤销核销)
* @param {string} keyId - API Key ID
* @param {number} quotaAmount - 要减少的 CC 额度
* @returns {Promise<Object>} { success: boolean, newQuotaLimit: number, actualDeducted: number }
* @param {number} amount - 要减少的金额USD
* @returns {Promise<Object>} { success: boolean, newTotalCostLimit: number, actualDeducted: number }
*/
async deductQuotaLimit(keyId, quotaAmount) {
async deductTotalCostLimit(keyId, amount) {
try {
const keyData = await redis.getApiKey(keyId)
if (!keyData || Object.keys(keyData).length === 0) {
throw new Error('API key not found')
}
if (keyData.isAggregated !== 'true') {
throw new Error('Only aggregated keys can deduct quota')
}
const currentLimit = parseFloat(keyData.quotaLimit || 0)
const currentUsed = parseFloat(keyData.quotaUsed || 0)
const currentLimit = parseFloat(keyData.totalCostLimit || 0)
const costStats = await redis.getCostStats(keyId)
const currentUsed = costStats?.total || 0
// 不能扣到比已使用的还少
const minLimit = currentUsed
const actualDeducted = Math.min(quotaAmount, currentLimit - minLimit)
const newLimit = Math.max(currentLimit - quotaAmount, minLimit)
const actualDeducted = Math.min(amount, currentLimit - minLimit)
const newLimit = Math.max(currentLimit - amount, minLimit)
await redis.client.hset(`api_key:${keyId}`, 'quotaLimit', String(newLimit))
await redis.client.hset(`apikey:${keyId}`, 'totalCostLimit', String(newLimit))
logger.success(
`💸 Deducted ${actualDeducted} CC quota from key ${keyId}, new limit: ${newLimit}`
)
logger.success(`💸 Deducted $${actualDeducted} from key ${keyId}, new limit: $${newLimit}`)
return { success: true, previousLimit: currentLimit, newQuotaLimit: newLimit, actualDeducted }
return { success: true, previousLimit: currentLimit, newTotalCostLimit: newLimit, actualDeducted }
} catch (error) {
logger.error('❌ Failed to deduct quota limit:', error)
logger.error('❌ Failed to deduct total cost limit:', error)
throw error
}
}
@@ -2643,4 +2593,8 @@ const apiKeyService = new ApiKeyService()
// 为了方便其他服务调用,导出 recordUsage 方法
apiKeyService.recordUsageMetrics = apiKeyService.recordUsage.bind(apiKeyService)
// 导出权限辅助函数供路由使用
apiKeyService.hasPermission = hasPermission
apiKeyService.normalizePermissions = normalizePermissions
module.exports = apiKeyService

View File

@@ -0,0 +1,133 @@
const axios = require('axios')
const logger = require('../../utils/logger')
const ProxyHelper = require('../../utils/proxyHelper')
/**
* Provider 抽象基类
* 各平台 Provider 需继承并实现 queryBalance(account)
*/
class BaseBalanceProvider {
constructor(platform) {
this.platform = platform
this.logger = logger
}
/**
* 查询余额(抽象方法)
* @param {object} account - 账户对象
* @returns {Promise<object>}
* 形如:
* {
* balance: number|null,
* currency?: string,
* quota?: { daily, used, remaining, resetAt, percentage, unlimited? },
* queryMethod?: 'api'|'field'|'local',
* rawData?: any
* }
*/
async queryBalance(_account) {
throw new Error('queryBalance 方法必须由子类实现')
}
/**
* 通用 HTTP 请求方法(支持代理)
* @param {string} url
* @param {object} options
* @param {object} account
*/
async makeRequest(url, options = {}, account = {}) {
const config = {
url,
method: options.method || 'GET',
headers: options.headers || {},
timeout: options.timeout || 15000,
data: options.data,
params: options.params,
responseType: options.responseType
}
const proxyConfig = account.proxyConfig || account.proxy
if (proxyConfig) {
const agent = ProxyHelper.createProxyAgent(proxyConfig)
if (agent) {
config.httpAgent = agent
config.httpsAgent = agent
config.proxy = false
}
}
try {
const response = await axios(config)
return {
success: true,
data: response.data,
status: response.status,
headers: response.headers
}
} catch (error) {
const status = error.response?.status
const message = error.response?.data?.message || error.message || '请求失败'
this.logger.debug(`余额 Provider HTTP 请求失败: ${url} (${this.platform})`, {
status,
message
})
return { success: false, status, error: message }
}
}
/**
* 从账户字段读取 dailyQuota / dailyUsage通用降级方案
* 注意:部分平台 dailyUsage 字段可能不是实时值,最终以 AccountBalanceService 的本地统计为准
*/
readQuotaFromFields(account) {
const dailyQuota = Number(account?.dailyQuota || 0)
const dailyUsage = Number(account?.dailyUsage || 0)
// 无限制
if (!Number.isFinite(dailyQuota) || dailyQuota <= 0) {
return {
balance: null,
currency: 'USD',
quota: {
daily: Infinity,
used: Number.isFinite(dailyUsage) ? dailyUsage : 0,
remaining: Infinity,
percentage: 0,
unlimited: true
},
queryMethod: 'field'
}
}
const used = Number.isFinite(dailyUsage) ? dailyUsage : 0
const remaining = Math.max(0, dailyQuota - used)
const percentage = dailyQuota > 0 ? (used / dailyQuota) * 100 : 0
return {
balance: remaining,
currency: 'USD',
quota: {
daily: dailyQuota,
used,
remaining,
percentage: Math.round(percentage * 100) / 100
},
queryMethod: 'field'
}
}
parseCurrency(data) {
return data?.currency || data?.Currency || 'USD'
}
async safeExecute(fn, fallbackValue = null) {
try {
return await fn()
} catch (error) {
this.logger.error(`余额 Provider 执行失败: ${this.platform}`, error)
return fallbackValue
}
}
}
module.exports = BaseBalanceProvider

View File

@@ -0,0 +1,30 @@
const BaseBalanceProvider = require('./baseBalanceProvider')
const claudeAccountService = require('../claudeAccountService')
class ClaudeBalanceProvider extends BaseBalanceProvider {
constructor() {
super('claude')
}
/**
* ClaudeOAuth优先尝试获取 OAuth usage用于配额/使用信息),不强行提供余额金额
*/
async queryBalance(account) {
this.logger.debug(`查询 Claude 余额OAuth usage: ${account?.id}`)
// 仅 OAuth 账户可用;失败时降级
const usageData = await claudeAccountService.fetchOAuthUsage(account.id).catch(() => null)
if (!usageData) {
return { balance: null, currency: 'USD', queryMethod: 'local' }
}
return {
balance: null,
currency: 'USD',
queryMethod: 'api',
rawData: usageData
}
}
}
module.exports = ClaudeBalanceProvider

View File

@@ -0,0 +1,14 @@
const BaseBalanceProvider = require('./baseBalanceProvider')
class ClaudeConsoleBalanceProvider extends BaseBalanceProvider {
constructor() {
super('claude-console')
}
async queryBalance(account) {
this.logger.debug(`查询 Claude Console 余额(字段): ${account?.id}`)
return this.readQuotaFromFields(account)
}
}
module.exports = ClaudeConsoleBalanceProvider

View File

@@ -0,0 +1,250 @@
const BaseBalanceProvider = require('./baseBalanceProvider')
const antigravityClient = require('../antigravityClient')
const geminiAccountService = require('../geminiAccountService')
const OAUTH_PROVIDER_ANTIGRAVITY = 'antigravity'
function clamp01(value) {
if (typeof value !== 'number' || !Number.isFinite(value)) {
return null
}
if (value < 0) {
return 0
}
if (value > 1) {
return 1
}
return value
}
function round2(value) {
if (typeof value !== 'number' || !Number.isFinite(value)) {
return null
}
return Math.round(value * 100) / 100
}
function normalizeQuotaCategory(displayName, modelId) {
const name = String(displayName || '')
const id = String(modelId || '')
if (name.includes('Gemini') && name.includes('Pro')) {
return 'Gemini Pro'
}
if (name.includes('Gemini') && name.includes('Flash')) {
return 'Gemini Flash'
}
if (name.includes('Gemini') && name.toLowerCase().includes('image')) {
return 'Gemini Image'
}
if (name.includes('Claude') || name.includes('GPT-OSS')) {
return 'Claude'
}
if (id.startsWith('gemini-3-pro-') || id.startsWith('gemini-2.5-pro')) {
return 'Gemini Pro'
}
if (id.startsWith('gemini-3-flash') || id.startsWith('gemini-2.5-flash')) {
return 'Gemini Flash'
}
if (id.includes('image')) {
return 'Gemini Image'
}
if (id.includes('claude') || id.includes('gpt-oss')) {
return 'Claude'
}
return name || id || 'Unknown'
}
function buildAntigravityQuota(modelsResponse) {
const models = modelsResponse && typeof modelsResponse === 'object' ? modelsResponse.models : null
if (!models || typeof models !== 'object') {
return null
}
const parseRemainingFraction = (quotaInfo) => {
if (!quotaInfo || typeof quotaInfo !== 'object') {
return null
}
const raw =
quotaInfo.remainingFraction ??
quotaInfo.remaining_fraction ??
quotaInfo.remaining ??
undefined
const num = typeof raw === 'number' ? raw : typeof raw === 'string' ? Number(raw) : NaN
if (!Number.isFinite(num)) {
return null
}
return clamp01(num)
}
const allowedCategories = new Set(['Gemini Pro', 'Claude', 'Gemini Flash', 'Gemini Image'])
const fixedOrder = ['Gemini Pro', 'Claude', 'Gemini Flash', 'Gemini Image']
const categoryMap = new Map()
for (const [modelId, modelDataRaw] of Object.entries(models)) {
if (!modelDataRaw || typeof modelDataRaw !== 'object') {
continue
}
const displayName = modelDataRaw.displayName || modelDataRaw.display_name || modelId
const quotaInfo = modelDataRaw.quotaInfo || modelDataRaw.quota_info || null
const remainingFraction = parseRemainingFraction(quotaInfo)
if (remainingFraction === null) {
continue
}
const remainingPercent = round2(remainingFraction * 100)
const usedPercent = round2(100 - remainingPercent)
const resetAt = quotaInfo?.resetTime || quotaInfo?.reset_time || null
const category = normalizeQuotaCategory(displayName, modelId)
if (!allowedCategories.has(category)) {
continue
}
const entry = {
category,
modelId,
displayName: String(displayName || modelId || category),
remainingPercent,
usedPercent,
resetAt: typeof resetAt === 'string' && resetAt.trim() ? resetAt : null
}
const existing = categoryMap.get(category)
if (!existing || entry.remainingPercent < existing.remainingPercent) {
categoryMap.set(category, entry)
}
}
const buckets = fixedOrder.map((category) => {
const existing = categoryMap.get(category) || null
if (existing) {
return existing
}
return {
category,
modelId: '',
displayName: category,
remainingPercent: null,
usedPercent: null,
resetAt: null
}
})
if (buckets.length === 0) {
return null
}
const critical = buckets
.filter((item) => item.remainingPercent !== null)
.reduce((min, item) => {
if (!min) {
return item
}
return (item.remainingPercent ?? 0) < (min.remainingPercent ?? 0) ? item : min
}, null)
if (!critical) {
return null
}
return {
balance: null,
currency: 'USD',
quota: {
type: 'antigravity',
total: 100,
used: critical.usedPercent,
remaining: critical.remainingPercent,
percentage: critical.usedPercent,
resetAt: critical.resetAt,
buckets: buckets.map((item) => ({
category: item.category,
remaining: item.remainingPercent,
used: item.usedPercent,
percentage: item.usedPercent,
resetAt: item.resetAt
}))
},
queryMethod: 'api',
rawData: {
modelsCount: Object.keys(models).length,
bucketCount: buckets.length
}
}
}
class GeminiBalanceProvider extends BaseBalanceProvider {
constructor() {
super('gemini')
}
async queryBalance(account) {
const oauthProvider = account?.oauthProvider
if (oauthProvider !== OAUTH_PROVIDER_ANTIGRAVITY) {
if (account && Object.prototype.hasOwnProperty.call(account, 'dailyQuota')) {
return this.readQuotaFromFields(account)
}
return { balance: null, currency: 'USD', queryMethod: 'local' }
}
const accessToken = String(account?.accessToken || '').trim()
const refreshToken = String(account?.refreshToken || '').trim()
const proxyConfig = account?.proxyConfig || account?.proxy || null
if (!accessToken) {
throw new Error('Antigravity 账户缺少 accessToken')
}
const fetch = async (token) =>
await antigravityClient.fetchAvailableModels({
accessToken: token,
proxyConfig
})
let data
try {
data = await fetch(accessToken)
} catch (error) {
const status = error?.response?.status
if ((status === 401 || status === 403) && refreshToken) {
const refreshed = await geminiAccountService.refreshAccessToken(
refreshToken,
proxyConfig,
OAUTH_PROVIDER_ANTIGRAVITY
)
const nextToken = String(refreshed?.access_token || '').trim()
if (!nextToken) {
throw error
}
data = await fetch(nextToken)
} else {
throw error
}
}
const mapped = buildAntigravityQuota(data)
if (!mapped) {
return {
balance: null,
currency: 'USD',
quota: null,
queryMethod: 'api',
rawData: data || null
}
}
return mapped
}
}
module.exports = GeminiBalanceProvider

View File

@@ -0,0 +1,23 @@
const BaseBalanceProvider = require('./baseBalanceProvider')
class GenericBalanceProvider extends BaseBalanceProvider {
constructor(platform) {
super(platform)
}
async queryBalance(account) {
this.logger.debug(`${this.platform} 暂无专用余额 API实现降级策略`)
if (account && Object.prototype.hasOwnProperty.call(account, 'dailyQuota')) {
return this.readQuotaFromFields(account)
}
return {
balance: null,
currency: 'USD',
queryMethod: 'local'
}
}
}
module.exports = GenericBalanceProvider

View File

@@ -0,0 +1,25 @@
const ClaudeBalanceProvider = require('./claudeBalanceProvider')
const ClaudeConsoleBalanceProvider = require('./claudeConsoleBalanceProvider')
const OpenAIResponsesBalanceProvider = require('./openaiResponsesBalanceProvider')
const GenericBalanceProvider = require('./genericBalanceProvider')
const GeminiBalanceProvider = require('./geminiBalanceProvider')
function registerAllProviders(balanceService) {
// Claude
balanceService.registerProvider('claude', new ClaudeBalanceProvider())
balanceService.registerProvider('claude-console', new ClaudeConsoleBalanceProvider())
// OpenAI / Codex
balanceService.registerProvider('openai-responses', new OpenAIResponsesBalanceProvider())
balanceService.registerProvider('openai', new GenericBalanceProvider('openai'))
balanceService.registerProvider('azure_openai', new GenericBalanceProvider('azure_openai'))
// 其他平台(降级)
balanceService.registerProvider('gemini', new GeminiBalanceProvider())
balanceService.registerProvider('gemini-api', new GenericBalanceProvider('gemini-api'))
balanceService.registerProvider('bedrock', new GenericBalanceProvider('bedrock'))
balanceService.registerProvider('droid', new GenericBalanceProvider('droid'))
balanceService.registerProvider('ccr', new GenericBalanceProvider('ccr'))
}
module.exports = { registerAllProviders }

View File

@@ -0,0 +1,54 @@
const BaseBalanceProvider = require('./baseBalanceProvider')
class OpenAIResponsesBalanceProvider extends BaseBalanceProvider {
constructor() {
super('openai-responses')
}
/**
* OpenAI-Responses
* - 优先使用 dailyQuota 字段(如果配置了额度)
* - 可选:尝试调用兼容 API不同服务商实现不一失败自动降级
*/
async queryBalance(account) {
this.logger.debug(`查询 OpenAI Responses 余额: ${account?.id}`)
// 配置了额度时直接返回(字段法)
if (account?.dailyQuota && Number(account.dailyQuota) > 0) {
return this.readQuotaFromFields(account)
}
// 尝试调用 usage 接口(兼容性不保证)
if (account?.apiKey && account?.baseApi) {
const baseApi = String(account.baseApi).replace(/\/$/, '')
const response = await this.makeRequest(
`${baseApi}/v1/usage`,
{
method: 'GET',
headers: {
Authorization: `Bearer ${account.apiKey}`,
'Content-Type': 'application/json'
}
},
account
)
if (response.success) {
return {
balance: null,
currency: this.parseCurrency(response.data),
queryMethod: 'api',
rawData: response.data
}
}
}
return {
balance: null,
currency: 'USD',
queryMethod: 'local'
}
}
}
module.exports = OpenAIResponsesBalanceProvider

View File

@@ -0,0 +1,210 @@
const vm = require('vm')
const axios = require('axios')
const { isBalanceScriptEnabled } = require('../utils/featureFlags')
/**
* SSRF防护检查URL是否访问内网或敏感地址
* @param {string} url - 要检查的URL
* @returns {boolean} - true表示URL安全
*/
function isUrlSafe(url) {
try {
const parsed = new URL(url)
const hostname = parsed.hostname.toLowerCase()
// 禁止的协议
if (!['http:', 'https:'].includes(parsed.protocol)) {
return false
}
// 禁止访问localhost和私有IP
const privatePatterns = [
/^localhost$/i,
/^127\./,
/^10\./,
/^172\.(1[6-9]|2[0-9]|3[0-1])\./,
/^192\.168\./,
/^169\.254\./, // AWS metadata
/^0\./, // 0.0.0.0
/^::1$/,
/^fc00:/i,
/^fe80:/i,
/\.local$/i,
/\.internal$/i,
/\.localhost$/i
]
for (const pattern of privatePatterns) {
if (pattern.test(hostname)) {
return false
}
}
return true
} catch {
return false
}
}
/**
* 可配置脚本余额查询执行器
* - 脚本格式:({ request: {...}, extractor: function(response){...} })
* - 模板变量:{{baseUrl}}, {{apiKey}}, {{token}}, {{accountId}}, {{platform}}, {{extra}}
*/
class BalanceScriptService {
/**
* 执行脚本:返回标准余额结构 + 原始响应
* @param {object} options
* - scriptBody: string
* - variables: Record<string,string>
* - timeoutSeconds: number
*/
async execute(options = {}) {
if (!isBalanceScriptEnabled()) {
const error = new Error('余额脚本功能已禁用(可通过 BALANCE_SCRIPT_ENABLED=true 启用)')
error.code = 'BALANCE_SCRIPT_DISABLED'
throw error
}
const scriptBody = options.scriptBody?.trim()
if (!scriptBody) {
throw new Error('脚本内容为空')
}
const timeoutMs = Math.max(1, (options.timeoutSeconds || 10) * 1000)
const sandbox = {
console,
Math,
Date
}
let scriptResult
try {
const wrapped = scriptBody.startsWith('(') ? scriptBody : `(${scriptBody})`
const script = new vm.Script(wrapped)
scriptResult = script.runInNewContext(sandbox, { timeout: timeoutMs })
} catch (error) {
throw new Error(`脚本解析失败: ${error.message}`)
}
if (!scriptResult || typeof scriptResult !== 'object') {
throw new Error('脚本返回格式无效(需返回 { request, extractor }')
}
const variables = options.variables || {}
const request = this.applyTemplates(scriptResult.request || {}, variables)
const { extractor } = scriptResult
if (!request?.url || typeof request.url !== 'string') {
throw new Error('脚本 request.url 不能为空')
}
// SSRF防护验证URL安全性
if (!isUrlSafe(request.url)) {
throw new Error('脚本 request.url 不安全禁止访问内网地址、localhost或使用非HTTP(S)协议')
}
if (typeof extractor !== 'function') {
throw new Error('脚本 extractor 必须是函数')
}
const axiosConfig = {
url: request.url,
method: (request.method || 'GET').toUpperCase(),
headers: request.headers || {},
timeout: timeoutMs
}
if (request.params) {
axiosConfig.params = request.params
}
if (request.body || request.data) {
axiosConfig.data = request.body || request.data
}
let httpResponse
try {
httpResponse = await axios(axiosConfig)
} catch (error) {
const { response } = error || {}
const { status, data } = response || {}
throw new Error(
`请求失败: ${status || ''} ${error.message}${data ? ` | ${JSON.stringify(data)}` : ''}`
)
}
const responseData = httpResponse?.data
let extracted = {}
try {
extracted = extractor(responseData) || {}
} catch (error) {
throw new Error(`extractor 执行失败: ${error.message}`)
}
const mapped = this.mapExtractorResult(extracted, responseData)
return {
mapped,
extracted,
response: {
status: httpResponse?.status,
headers: httpResponse?.headers,
data: responseData
}
}
}
applyTemplates(value, variables) {
if (typeof value === 'string') {
return value.replace(/{{(\w+)}}/g, (_, key) => {
const trimmed = key.trim()
return variables[trimmed] !== undefined ? String(variables[trimmed]) : ''
})
}
if (Array.isArray(value)) {
return value.map((item) => this.applyTemplates(item, variables))
}
if (value && typeof value === 'object') {
const result = {}
Object.keys(value).forEach((k) => {
result[k] = this.applyTemplates(value[k], variables)
})
return result
}
return value
}
mapExtractorResult(result = {}, responseData) {
const isValid = result.isValid !== false
const remaining = Number(result.remaining)
const total = Number(result.total)
const used = Number(result.used)
const currency = result.unit || 'USD'
const quota =
Number.isFinite(total) || Number.isFinite(used)
? {
total: Number.isFinite(total) ? total : null,
used: Number.isFinite(used) ? used : null,
remaining: Number.isFinite(remaining) ? remaining : null,
percentage:
Number.isFinite(total) && total > 0 && Number.isFinite(used)
? (used / total) * 100
: null
}
: null
return {
status: isValid ? 'success' : 'error',
errorMessage: isValid ? '' : result.invalidMessage || '套餐无效',
balance: Number.isFinite(remaining) ? remaining : null,
currency,
quota,
planName: result.planName || null,
extra: result.extra || null,
rawData: responseData || result.raw
}
}
}
module.exports = new BalanceScriptService()

View File

@@ -35,12 +35,13 @@ class BedrockAccountService {
description = '',
region = process.env.AWS_REGION || 'us-east-1',
awsCredentials = null, // { accessKeyId, secretAccessKey, sessionToken }
bearerToken = null, // AWS Bearer Token for Bedrock API Keys
defaultModel = 'us.anthropic.claude-sonnet-4-20250514-v1:0',
isActive = true,
accountType = 'shared', // 'dedicated' or 'shared'
priority = 50, // 调度优先级 (1-100数字越小优先级越高)
schedulable = true, // 是否可被调度
credentialType = 'default' // 'default', 'access_key', 'bearer_token'
credentialType = 'access_key' // 'access_key', 'bearer_token'(默认为 access_key
} = options
const accountId = uuidv4()
@@ -71,6 +72,11 @@ class BedrockAccountService {
accountData.awsCredentials = this._encryptAwsCredentials(awsCredentials)
}
// 加密存储 Bearer Token
if (bearerToken) {
accountData.bearerToken = this._encryptAwsCredentials({ token: bearerToken })
}
const client = redis.getClientSafe()
await client.set(`bedrock_account:${accountId}`, JSON.stringify(accountData))
await redis.addToIndex('bedrock_account:index', accountId)
@@ -107,9 +113,85 @@ class BedrockAccountService {
const account = JSON.parse(accountData)
// 解密AWS凭证用于内部使用
if (account.awsCredentials) {
account.awsCredentials = this._decryptAwsCredentials(account.awsCredentials)
// 根据凭证类型解密对应的凭证
// 增强逻辑:优先按照 credentialType 解密,如果字段不存在则尝试解密实际存在的字段(兜底)
try {
let accessKeyDecrypted = false
let bearerTokenDecrypted = false
// 第一步:按照 credentialType 尝试解密对应的凭证
if (account.credentialType === 'access_key' && account.awsCredentials) {
// Access Key 模式:解密 AWS 凭证
account.awsCredentials = this._decryptAwsCredentials(account.awsCredentials)
accessKeyDecrypted = true
logger.debug(
`🔓 解密 Access Key 成功 - ID: ${accountId}, 类型: ${account.credentialType}`
)
} else if (account.credentialType === 'bearer_token' && account.bearerToken) {
// Bearer Token 模式:解密 Bearer Token
const decrypted = this._decryptAwsCredentials(account.bearerToken)
account.bearerToken = decrypted.token
bearerTokenDecrypted = true
logger.debug(
`🔓 解密 Bearer Token 成功 - ID: ${accountId}, 类型: ${account.credentialType}`
)
} else if (!account.credentialType || account.credentialType === 'default') {
// 向后兼容:旧版本账号可能没有 credentialType 字段,尝试解密所有存在的凭证
if (account.awsCredentials) {
account.awsCredentials = this._decryptAwsCredentials(account.awsCredentials)
accessKeyDecrypted = true
}
if (account.bearerToken) {
const decrypted = this._decryptAwsCredentials(account.bearerToken)
account.bearerToken = decrypted.token
bearerTokenDecrypted = true
}
logger.debug(
`🔓 兼容模式解密 - ID: ${accountId}, Access Key: ${accessKeyDecrypted}, Bearer Token: ${bearerTokenDecrypted}`
)
}
// 第二步:兜底逻辑 - 如果按照 credentialType 没有解密到任何凭证,尝试解密实际存在的字段
if (!accessKeyDecrypted && !bearerTokenDecrypted) {
logger.warn(
`⚠️ credentialType="${account.credentialType}" 与实际字段不匹配,尝试兜底解密 - ID: ${accountId}`
)
if (account.awsCredentials) {
account.awsCredentials = this._decryptAwsCredentials(account.awsCredentials)
accessKeyDecrypted = true
logger.warn(
`🔓 兜底解密 Access Key 成功 - ID: ${accountId}, credentialType 应为 'access_key'`
)
}
if (account.bearerToken) {
const decrypted = this._decryptAwsCredentials(account.bearerToken)
account.bearerToken = decrypted.token
bearerTokenDecrypted = true
logger.warn(
`🔓 兜底解密 Bearer Token 成功 - ID: ${accountId}, credentialType 应为 'bearer_token'`
)
}
}
// 验证至少解密了一种凭证
if (!accessKeyDecrypted && !bearerTokenDecrypted) {
logger.error(
`❌ 未找到任何凭证可解密 - ID: ${accountId}, credentialType: ${account.credentialType}, hasAwsCredentials: ${!!account.awsCredentials}, hasBearerToken: ${!!account.bearerToken}`
)
return {
success: false,
error: 'No valid credentials found in account data'
}
}
} catch (decryptError) {
logger.error(
`❌ 解密Bedrock凭证失败 - ID: ${accountId}, 类型: ${account.credentialType}`,
decryptError
)
return {
success: false,
error: `Credentials decryption failed: ${decryptError.message}`
}
}
logger.debug(`🔍 获取Bedrock账户 - ID: ${accountId}, 名称: ${account.name}`)
@@ -162,7 +244,11 @@ class BedrockAccountService {
updatedAt: account.updatedAt,
type: 'bedrock',
platform: 'bedrock',
hasCredentials: !!account.awsCredentials
// 根据凭证类型判断是否有凭证
hasCredentials:
account.credentialType === 'bearer_token'
? !!account.bearerToken
: !!account.awsCredentials
})
}
}
@@ -242,6 +328,15 @@ class BedrockAccountService {
logger.info(`🔐 重新加密Bedrock账户凭证 - ID: ${accountId}`)
}
// 更新 Bearer Token
if (updates.bearerToken !== undefined) {
if (updates.bearerToken) {
account.bearerToken = this._encryptAwsCredentials({ token: updates.bearerToken })
} else {
delete account.bearerToken
}
}
// ✅ 直接保存 subscriptionExpiresAt如果提供
// Bedrock 没有 token 刷新逻辑,不会覆盖此字段
if (updates.subscriptionExpiresAt !== undefined) {
@@ -353,13 +448,45 @@ class BedrockAccountService {
const account = accountResult.data
logger.info(`🧪 测试Bedrock账户连接 - ID: ${accountId}, 名称: ${account.name}`)
logger.info(
`🧪 测试Bedrock账户连接 - ID: ${accountId}, 名称: ${account.name}, 凭证类型: ${account.credentialType}`
)
// 尝试获取模型列表来测试连接
// 验证凭证是否已解密
const hasValidCredentials =
(account.credentialType === 'access_key' && account.awsCredentials) ||
(account.credentialType === 'bearer_token' && account.bearerToken) ||
(!account.credentialType && (account.awsCredentials || account.bearerToken))
if (!hasValidCredentials) {
logger.error(
`❌ 测试失败:账户没有有效凭证 - ID: ${accountId}, credentialType: ${account.credentialType}`
)
return {
success: false,
error: 'No valid credentials found after decryption'
}
}
// 尝试创建 Bedrock 客户端来验证凭证格式
try {
bedrockRelayService._getBedrockClient(account.region, account)
logger.debug(`✅ Bedrock客户端创建成功 - ID: ${accountId}`)
} catch (clientError) {
logger.error(`❌ 创建Bedrock客户端失败 - ID: ${accountId}`, clientError)
return {
success: false,
error: `Failed to create Bedrock client: ${clientError.message}`
}
}
// 获取可用模型列表(硬编码,但至少验证了凭证格式正确)
const models = await bedrockRelayService.getAvailableModels(account)
if (models && models.length > 0) {
logger.info(`✅ Bedrock账户测试成功 - ID: ${accountId}, 发现 ${models.length} 个模型`)
logger.info(
`✅ Bedrock账户测试成功 - ID: ${accountId}, 发现 ${models.length} 个模型, 凭证类型: ${account.credentialType}`
)
return {
success: true,
data: {
@@ -384,6 +511,135 @@ class BedrockAccountService {
}
}
/**
* 🧪 测试 Bedrock 账户连接SSE 流式返回,供前端测试页面使用)
* @param {string} accountId - 账户ID
* @param {Object} res - Express response 对象
* @param {string} model - 测试使用的模型
*/
async testAccountConnection(accountId, res, model = null) {
const { InvokeModelWithResponseStreamCommand } = require('@aws-sdk/client-bedrock-runtime')
try {
// 获取账户信息
const accountResult = await this.getAccount(accountId)
if (!accountResult.success) {
throw new Error(accountResult.error || 'Account not found')
}
const account = accountResult.data
// 根据账户类型选择合适的测试模型
if (!model) {
// Access Key 模式使用 Haiku更快更便宜
model = account.defaultModel || 'us.anthropic.claude-3-5-haiku-20241022-v1:0'
}
logger.info(
`🧪 Testing Bedrock account connection: ${account.name} (${accountId}), model: ${model}, credentialType: ${account.credentialType}`
)
// 设置 SSE 响应头
res.setHeader('Content-Type', 'text/event-stream')
res.setHeader('Cache-Control', 'no-cache')
res.setHeader('Connection', 'keep-alive')
res.setHeader('X-Accel-Buffering', 'no')
res.status(200)
// 发送 test_start 事件
res.write(`data: ${JSON.stringify({ type: 'test_start' })}\n\n`)
// 构造测试请求体Bedrock 格式)
const bedrockPayload = {
anthropic_version: 'bedrock-2023-05-31',
max_tokens: 256,
messages: [
{
role: 'user',
content:
'Hello! Please respond with a simple greeting to confirm the connection is working. And tell me who are you?'
}
]
}
// 获取 Bedrock 客户端
const region = account.region || bedrockRelayService.defaultRegion
const client = bedrockRelayService._getBedrockClient(region, account)
// 创建流式调用命令
const command = new InvokeModelWithResponseStreamCommand({
modelId: model,
body: JSON.stringify(bedrockPayload),
contentType: 'application/json',
accept: 'application/json'
})
logger.debug(`🌊 Bedrock test stream - model: ${model}, region: ${region}`)
const startTime = Date.now()
const response = await client.send(command)
// 处理流式响应
// let responseText = ''
for await (const chunk of response.body) {
if (chunk.chunk) {
const chunkData = JSON.parse(new TextDecoder().decode(chunk.chunk.bytes))
// 提取文本内容
if (chunkData.type === 'content_block_delta' && chunkData.delta?.text) {
const { text } = chunkData.delta
// responseText += text
// 发送 content 事件
res.write(`data: ${JSON.stringify({ type: 'content', text })}\n\n`)
}
// 检测错误
if (chunkData.type === 'error') {
throw new Error(chunkData.error?.message || 'Bedrock API error')
}
}
}
const duration = Date.now() - startTime
logger.info(`✅ Bedrock test completed - model: ${model}, duration: ${duration}ms`)
// 发送 message_stop 事件(前端兼容)
res.write(`data: ${JSON.stringify({ type: 'message_stop' })}\n\n`)
// 发送 test_complete 事件
res.write(`data: ${JSON.stringify({ type: 'test_complete', success: true })}\n\n`)
// 结束响应
res.end()
logger.info(`✅ Test request completed for Bedrock account: ${account.name}`)
} catch (error) {
logger.error(`❌ Test Bedrock account connection failed:`, error)
// 发送错误事件给前端
try {
// 检查响应流是否仍然可写
if (!res.writableEnded && !res.destroyed) {
if (!res.headersSent) {
res.setHeader('Content-Type', 'text/event-stream')
res.setHeader('Cache-Control', 'no-cache')
res.setHeader('Connection', 'keep-alive')
res.status(200)
}
const errorMsg = error.message || '测试失败'
res.write(`data: ${JSON.stringify({ type: 'error', error: errorMsg })}\n\n`)
res.end()
}
} catch (writeError) {
logger.error('Failed to write error to response stream:', writeError)
}
// 不再重新抛出错误,避免路由层再次处理
// throw error
}
}
/**
* 检查账户订阅是否过期
* @param {Object} account - 账户对象

View File

@@ -48,13 +48,17 @@ class BedrockRelayService {
secretAccessKey: bedrockAccount.awsCredentials.secretAccessKey,
sessionToken: bedrockAccount.awsCredentials.sessionToken
}
} else if (bedrockAccount?.bearerToken) {
// Bearer Token 模式AWS SDK >= 3.400.0 会自动检测环境变量
clientConfig.token = { token: bedrockAccount.bearerToken }
logger.debug(`🔑 使用 Bearer Token 认证 - 账户: ${bedrockAccount.name || 'unknown'}`)
} else {
// 检查是否有环境变量凭证
if (process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) {
clientConfig.credentials = fromEnv()
} else {
throw new Error(
'AWS凭证未配置。请在Bedrock账户中配置AWS访问密钥或设置环境变量AWS_ACCESS_KEY_ID和AWS_SECRET_ACCESS_KEY'
'AWS凭证未配置。请在Bedrock账户中配置AWS访问密钥或Bearer Token或设置环境变量AWS_ACCESS_KEY_ID和AWS_SECRET_ACCESS_KEY'
)
}
}
@@ -431,6 +435,18 @@ class BedrockRelayService {
_mapToBedrockModel(modelName) {
// 标准Claude模型名到Bedrock模型名的映射表
const modelMapping = {
// Claude 4.5 Opus
'claude-opus-4-5': 'us.anthropic.claude-opus-4-5-20251101-v1:0',
'claude-opus-4-5-20251101': 'us.anthropic.claude-opus-4-5-20251101-v1:0',
// Claude 4.5 Sonnet
'claude-sonnet-4-5': 'us.anthropic.claude-sonnet-4-5-20250929-v1:0',
'claude-sonnet-4-5-20250929': 'us.anthropic.claude-sonnet-4-5-20250929-v1:0',
// Claude 4.5 Haiku
'claude-haiku-4-5': 'us.anthropic.claude-haiku-4-5-20251001-v1:0',
'claude-haiku-4-5-20251001': 'us.anthropic.claude-haiku-4-5-20251001-v1:0',
// Claude Sonnet 4
'claude-sonnet-4': 'us.anthropic.claude-sonnet-4-20250514-v1:0',
'claude-sonnet-4-20250514': 'us.anthropic.claude-sonnet-4-20250514-v1:0',

View File

@@ -29,51 +29,51 @@ const safeClone =
class ClaudeRelayService {
constructor() {
this.claudeApiUrl = 'https://api.anthropic.com/v1/messages?beta=true'
// 🧹 内存优化:用于存储请求体字符串,避免闭包捕获
this.bodyStore = new Map()
this._bodyStoreIdCounter = 0
this.apiVersion = config.claude.apiVersion
this.betaHeader = config.claude.betaHeader
this.systemPrompt = config.claude.systemPrompt
this.claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
this.toolNameSuffix = null
this.toolNameSuffixGeneratedAt = 0
this.toolNameSuffixTtlMs = 60 * 60 * 1000
}
// 🔧 根据模型ID和客户端传递的 anthropic-beta 获取最终的 header
// 规则:
// 1. 如果客户端传递了 anthropic-beta检查是否包含 oauth-2025-04-20
// 2. 如果没有 oauth-2025-04-20则添加到 claude-code-20250219 后面(如果有的话),否则放在第一位
// 3. 如果客户端没传递则根据模型判断haiku 不需要 claude-code其他模型需要
_getBetaHeader(modelId, clientBetaHeader) {
const OAUTH_BETA = 'oauth-2025-04-20'
const CLAUDE_CODE_BETA = 'claude-code-20250219'
const INTERLEAVED_THINKING_BETA = 'interleaved-thinking-2025-05-14'
const TOOL_STREAMING_BETA = 'fine-grained-tool-streaming-2025-05-14'
// 如果客户端传递了 anthropic-beta
if (clientBetaHeader) {
// 检查是否已包含 oauth-2025-04-20
if (clientBetaHeader.includes(OAUTH_BETA)) {
return clientBetaHeader
}
// 需要添加 oauth-2025-04-20
const parts = clientBetaHeader.split(',').map((p) => p.trim())
// 找到 claude-code-20250219 的位置
const claudeCodeIndex = parts.findIndex((p) => p === CLAUDE_CODE_BETA)
if (claudeCodeIndex !== -1) {
// 在 claude-code-20250219 后面插入
parts.splice(claudeCodeIndex + 1, 0, OAUTH_BETA)
} else {
// 放在第一位
parts.unshift(OAUTH_BETA)
}
return parts.join(',')
}
// 客户端没有传递,根据模型判断
const isHaikuModel = modelId && modelId.toLowerCase().includes('haiku')
if (isHaikuModel) {
return 'oauth-2025-04-20,interleaved-thinking-2025-05-14'
const baseBetas = isHaikuModel
? [OAUTH_BETA, INTERLEAVED_THINKING_BETA]
: [CLAUDE_CODE_BETA, OAUTH_BETA, INTERLEAVED_THINKING_BETA, TOOL_STREAMING_BETA]
const betaList = []
const seen = new Set()
const addBeta = (beta) => {
if (!beta || seen.has(beta)) {
return
}
seen.add(beta)
betaList.push(beta)
}
return 'claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14'
baseBetas.forEach(addBeta)
if (clientBetaHeader) {
clientBetaHeader
.split(',')
.map((p) => p.trim())
.filter(Boolean)
.forEach(addBeta)
}
return betaList.join(',')
}
_buildStandardRateLimitMessage(resetTime) {
@@ -148,6 +148,235 @@ class ClaudeRelayService {
return ClaudeCodeValidator.includesClaudeCodeSystemPrompt(requestBody, 1)
}
_isClaudeCodeUserAgent(clientHeaders) {
const userAgent = clientHeaders?.['user-agent'] || clientHeaders?.['User-Agent']
return typeof userAgent === 'string' && /^claude-cli\/[^\s]+\s+\(/i.test(userAgent)
}
_isActualClaudeCodeRequest(requestBody, clientHeaders) {
return this.isRealClaudeCodeRequest(requestBody) && this._isClaudeCodeUserAgent(clientHeaders)
}
_getHeaderValueCaseInsensitive(headers, key) {
if (!headers || typeof headers !== 'object') {
return undefined
}
const lowerKey = key.toLowerCase()
for (const candidate of Object.keys(headers)) {
if (candidate.toLowerCase() === lowerKey) {
return headers[candidate]
}
}
return undefined
}
_isClaudeCodeCredentialError(body) {
const message = this._extractErrorMessage(body)
if (!message) {
return false
}
const lower = message.toLowerCase()
return (
lower.includes('only authorized for use with claude code') ||
lower.includes('cannot be used for other api requests')
)
}
_toPascalCaseToolName(name) {
const parts = name.split(/[_-]/).filter(Boolean)
if (parts.length === 0) {
return name
}
const pascal = parts
.map((part) => part.charAt(0).toUpperCase() + part.slice(1).toLowerCase())
.join('')
return `${pascal}_tool`
}
_getToolNameSuffix() {
const now = Date.now()
if (!this.toolNameSuffix || now - this.toolNameSuffixGeneratedAt > this.toolNameSuffixTtlMs) {
this.toolNameSuffix = Math.random().toString(36).substring(2, 8)
this.toolNameSuffixGeneratedAt = now
}
return this.toolNameSuffix
}
_toRandomizedToolName(name) {
const suffix = this._getToolNameSuffix()
return `${name}_${suffix}`
}
_transformToolNamesInRequestBody(body, options = {}) {
if (!body || typeof body !== 'object') {
return null
}
const useRandomized = options.useRandomizedToolNames === true
const forwardMap = new Map()
const reverseMap = new Map()
const transformName = (name) => {
if (typeof name !== 'string' || name.length === 0) {
return name
}
if (forwardMap.has(name)) {
return forwardMap.get(name)
}
const transformed = useRandomized
? this._toRandomizedToolName(name)
: this._toPascalCaseToolName(name)
if (transformed !== name) {
forwardMap.set(name, transformed)
reverseMap.set(transformed, name)
}
return transformed
}
if (Array.isArray(body.tools)) {
body.tools.forEach((tool) => {
if (tool && typeof tool.name === 'string') {
tool.name = transformName(tool.name)
}
})
}
if (body.tool_choice && typeof body.tool_choice === 'object') {
if (typeof body.tool_choice.name === 'string') {
body.tool_choice.name = transformName(body.tool_choice.name)
}
}
if (Array.isArray(body.messages)) {
body.messages.forEach((message) => {
const content = message?.content
if (Array.isArray(content)) {
content.forEach((block) => {
if (block?.type === 'tool_use' && typeof block.name === 'string') {
block.name = transformName(block.name)
}
})
}
})
}
return reverseMap.size > 0 ? reverseMap : null
}
_restoreToolName(name, toolNameMap) {
if (!toolNameMap || toolNameMap.size === 0) {
return name
}
return toolNameMap.get(name) || name
}
_restoreToolNamesInContentBlocks(content, toolNameMap) {
if (!Array.isArray(content)) {
return
}
content.forEach((block) => {
if (block?.type === 'tool_use' && typeof block.name === 'string') {
block.name = this._restoreToolName(block.name, toolNameMap)
}
})
}
_restoreToolNamesInResponseObject(responseBody, toolNameMap) {
if (!responseBody || typeof responseBody !== 'object') {
return
}
if (Array.isArray(responseBody.content)) {
this._restoreToolNamesInContentBlocks(responseBody.content, toolNameMap)
}
if (responseBody.message && Array.isArray(responseBody.message.content)) {
this._restoreToolNamesInContentBlocks(responseBody.message.content, toolNameMap)
}
}
_restoreToolNamesInResponseBody(responseBody, toolNameMap) {
if (!responseBody || !toolNameMap || toolNameMap.size === 0) {
return responseBody
}
if (typeof responseBody === 'string') {
try {
const parsed = JSON.parse(responseBody)
this._restoreToolNamesInResponseObject(parsed, toolNameMap)
return JSON.stringify(parsed)
} catch (error) {
return responseBody
}
}
if (typeof responseBody === 'object') {
this._restoreToolNamesInResponseObject(responseBody, toolNameMap)
}
return responseBody
}
_restoreToolNamesInStreamEvent(event, toolNameMap) {
if (!event || typeof event !== 'object') {
return
}
if (event.content_block && event.content_block.type === 'tool_use') {
if (typeof event.content_block.name === 'string') {
event.content_block.name = this._restoreToolName(event.content_block.name, toolNameMap)
}
}
if (event.delta && event.delta.type === 'tool_use') {
if (typeof event.delta.name === 'string') {
event.delta.name = this._restoreToolName(event.delta.name, toolNameMap)
}
}
if (event.message && Array.isArray(event.message.content)) {
this._restoreToolNamesInContentBlocks(event.message.content, toolNameMap)
}
if (Array.isArray(event.content)) {
this._restoreToolNamesInContentBlocks(event.content, toolNameMap)
}
}
_createToolNameStripperStreamTransformer(streamTransformer, toolNameMap) {
if (!toolNameMap || toolNameMap.size === 0) {
return streamTransformer
}
return (payload) => {
const transformed = streamTransformer ? streamTransformer(payload) : payload
if (!transformed || typeof transformed !== 'string') {
return transformed
}
const lines = transformed.split('\n')
const updated = lines.map((line) => {
if (!line.startsWith('data:')) {
return line
}
const jsonStr = line.slice(5).trimStart()
if (!jsonStr || jsonStr === '[DONE]') {
return line
}
try {
const data = JSON.parse(jsonStr)
this._restoreToolNamesInStreamEvent(data, toolNameMap)
return `data: ${JSON.stringify(data)}`
} catch (error) {
return line
}
})
return updated.join('\n')
}
}
// 🚀 转发请求到Claude API
async relayRequest(
requestBody,
@@ -161,6 +390,7 @@ class ClaudeRelayService {
let queueLockAcquired = false
let queueRequestId = null
let selectedAccountId = null
let bodyStoreIdNonStream = null // 🧹 在 try 块外声明,以便 finally 清理
try {
// 调试日志查看API Key数据
@@ -319,7 +549,12 @@ class ClaudeRelayService {
// 获取有效的访问token
const accessToken = await claudeAccountService.getValidAccessToken(accountId)
const isRealClaudeCodeRequest = this._isActualClaudeCodeRequest(requestBody, clientHeaders)
const processedBody = this._processRequestBody(requestBody, account)
// 🧹 内存优化:存储到 bodyStore避免闭包捕获
const originalBodyString = JSON.stringify(processedBody)
bodyStoreIdNonStream = ++this._bodyStoreIdCounter
this.bodyStore.set(bodyStoreIdNonStream, originalBodyString)
// 获取代理配置
const proxyAgent = await this._getProxyAgent(accountId)
@@ -340,36 +575,59 @@ class ClaudeRelayService {
clientResponse.once('close', handleClientDisconnect)
}
// 发送请求到Claude API传入回调以获取请求对象
// 🔄 403 重试机制:仅对 claude-official 类型账户OAuth 或 Setup Token
const maxRetries = this._shouldRetryOn403(accountType) ? 2 : 0
let retryCount = 0
let response
let shouldRetry = false
const makeRequestWithRetries = async (requestOptions) => {
const maxRetries = this._shouldRetryOn403(accountType) ? 2 : 0
let retryCount = 0
let response
let shouldRetry = false
do {
response = await this._makeClaudeRequest(
processedBody,
accessToken,
proxyAgent,
clientHeaders,
accountId,
(req) => {
upstreamRequest = req
},
options
)
// 检查是否需要重试 403
shouldRetry = response.statusCode === 403 && retryCount < maxRetries
if (shouldRetry) {
retryCount++
logger.warn(
`🔄 403 error for account ${accountId}, retry ${retryCount}/${maxRetries} after 2s`
do {
// 🧹 每次重试从 bodyStore 解析新对象,避免闭包捕获
let retryRequestBody
try {
retryRequestBody = JSON.parse(this.bodyStore.get(bodyStoreIdNonStream))
} catch (parseError) {
logger.error(`❌ Failed to parse body for retry: ${parseError.message}`)
throw new Error(`Request body parse failed: ${parseError.message}`)
}
response = await this._makeClaudeRequest(
retryRequestBody,
accessToken,
proxyAgent,
clientHeaders,
accountId,
(req) => {
upstreamRequest = req
},
{
...requestOptions,
isRealClaudeCodeRequest
}
)
await this._sleep(2000)
}
} while (shouldRetry)
shouldRetry = response.statusCode === 403 && retryCount < maxRetries
if (shouldRetry) {
retryCount++
logger.warn(
`🔄 403 error for account ${accountId}, retry ${retryCount}/${maxRetries} after 2s`
)
await this._sleep(2000)
}
} while (shouldRetry)
return { response, retryCount }
}
let requestOptions = options
let { response, retryCount } = await makeRequestWithRetries(requestOptions)
if (
this._isClaudeCodeCredentialError(response.body) &&
requestOptions.useRandomizedToolNames !== true
) {
requestOptions = { ...requestOptions, useRandomizedToolNames: true }
;({ response, retryCount } = await makeRequestWithRetries(requestOptions))
}
// 如果进行了重试,记录最终结果
if (retryCount > 0) {
@@ -669,6 +927,10 @@ class ClaudeRelayService {
)
throw error
} finally {
// 🧹 清理 bodyStore
if (bodyStoreIdNonStream !== null) {
this.bodyStore.delete(bodyStoreIdNonStream)
}
// 📬 释放用户消息队列锁(兜底,正常情况下已在请求发送后提前释放)
if (queueLockAcquired && queueRequestId && selectedAccountId) {
try {
@@ -1043,23 +1305,19 @@ class ClaudeRelayService {
// 获取过滤后的客户端 headers
const filteredHeaders = this._filterClientHeaders(clientHeaders)
// 判断是否是真实的 Claude Code 请求
const isRealClaudeCode = this.isRealClaudeCodeRequest(body)
const isRealClaudeCode =
requestOptions.isRealClaudeCodeRequest === undefined
? this.isRealClaudeCodeRequest(body)
: requestOptions.isRealClaudeCodeRequest === true
// 如果不是真实的 Claude Code 请求,需要使用从账户获取的 Claude Code headers
let finalHeaders = { ...filteredHeaders }
let requestPayload = body
if (!isRealClaudeCode) {
// 获取该账号存储的 Claude Code headers
const claudeCodeHeaders = await claudeCodeHeadersService.getAccountHeaders(accountId)
// 只添加客户端没有提供的 headers
Object.keys(claudeCodeHeaders).forEach((key) => {
const lowerKey = key.toLowerCase()
if (!finalHeaders[key] && !finalHeaders[lowerKey]) {
finalHeaders[key] = claudeCodeHeaders[key]
}
finalHeaders[key] = claudeCodeHeaders[key]
})
}
@@ -1081,6 +1339,13 @@ class ClaudeRelayService {
requestPayload = extensionResult.body
finalHeaders = extensionResult.headers
let toolNameMap = null
if (!isRealClaudeCode) {
toolNameMap = this._transformToolNamesInRequestBody(requestPayload, {
useRandomizedToolNames: requestOptions.useRandomizedToolNames === true
})
}
// 序列化请求体,计算 content-length
const bodyString = JSON.stringify(requestPayload)
const contentLength = Buffer.byteLength(bodyString, 'utf8')
@@ -1108,13 +1373,14 @@ class ClaudeRelayService {
// 根据模型和客户端传递的 anthropic-beta 动态设置 header
const modelId = requestPayload?.model || body?.model
const clientBetaHeader = clientHeaders?.['anthropic-beta']
const clientBetaHeader = this._getHeaderValueCaseInsensitive(clientHeaders, 'anthropic-beta')
headers['anthropic-beta'] = this._getBetaHeader(modelId, clientBetaHeader)
return {
requestPayload,
bodyString,
headers,
isRealClaudeCode
isRealClaudeCode,
toolNameMap
}
}
@@ -1180,7 +1446,8 @@ class ClaudeRelayService {
return prepared.abortResponse
}
const { bodyString, headers } = prepared
let { bodyString } = prepared
const { headers, isRealClaudeCode, toolNameMap } = prepared
return new Promise((resolve, reject) => {
// 支持自定义路径(如 count_tokens
@@ -1235,6 +1502,10 @@ class ClaudeRelayService {
responseBody = responseData.toString('utf8')
}
if (!isRealClaudeCode) {
responseBody = this._restoreToolNamesInResponseBody(responseBody, toolNameMap)
}
const response = {
statusCode: res.statusCode,
headers: res.headers,
@@ -1293,6 +1564,8 @@ class ClaudeRelayService {
// 写入请求体
req.write(bodyString)
// 🧹 内存优化:立即清空 bodyString 引用,避免闭包捕获
bodyString = null
req.end()
})
}
@@ -1474,7 +1747,12 @@ class ClaudeRelayService {
// 获取有效的访问token
const accessToken = await claudeAccountService.getValidAccessToken(accountId)
const isRealClaudeCodeRequest = this._isActualClaudeCodeRequest(requestBody, clientHeaders)
const processedBody = this._processRequestBody(requestBody, account)
// 🧹 内存优化:存储到 bodyStore不放入 requestOptions 避免闭包捕获
const originalBodyString = JSON.stringify(processedBody)
const bodyStoreId = ++this._bodyStoreIdCounter
this.bodyStore.set(bodyStoreId, originalBodyString)
// 获取代理配置
const proxyAgent = await this._getProxyAgent(accountId)
@@ -1496,7 +1774,11 @@ class ClaudeRelayService {
accountType,
sessionHash,
streamTransformer,
options,
{
...options,
bodyStoreId,
isRealClaudeCodeRequest
},
isDedicatedOfficialAccount,
// 📬 新增回调:在收到响应头时释放队列锁
async () => {
@@ -1585,7 +1867,12 @@ class ClaudeRelayService {
return prepared.abortResponse
}
const { bodyString, headers } = prepared
let { bodyString } = prepared
const { headers, toolNameMap } = prepared
const toolNameStreamTransformer = this._createToolNameStripperStreamTransformer(
streamTransformer,
toolNameMap
)
return new Promise((resolve, reject) => {
const url = new URL(this.claudeApiUrl)
@@ -1693,8 +1980,22 @@ class ClaudeRelayService {
try {
// 递归调用自身进行重试
// 🧹 从 bodyStore 获取字符串用于重试
if (
!requestOptions.bodyStoreId ||
!this.bodyStore.has(requestOptions.bodyStoreId)
) {
throw new Error('529 retry requires valid bodyStoreId')
}
let retryBody
try {
retryBody = JSON.parse(this.bodyStore.get(requestOptions.bodyStoreId))
} catch (parseError) {
logger.error(`❌ Failed to parse body for 529 retry: ${parseError.message}`)
throw new Error(`529 retry body parse failed: ${parseError.message}`)
}
const retryResult = await this._makeClaudeStreamRequestWithUsageCapture(
body,
retryBody,
accessToken,
proxyAgent,
clientHeaders,
@@ -1789,11 +2090,48 @@ class ClaudeRelayService {
errorData += chunk.toString()
})
res.on('end', () => {
res.on('end', async () => {
logger.error(
`❌ Claude API error response (Account: ${account?.name || accountId}):`,
errorData
)
if (
this._isClaudeCodeCredentialError(errorData) &&
requestOptions.useRandomizedToolNames !== true &&
requestOptions.bodyStoreId &&
this.bodyStore.has(requestOptions.bodyStoreId)
) {
let retryBody
try {
retryBody = JSON.parse(this.bodyStore.get(requestOptions.bodyStoreId))
} catch (parseError) {
logger.error(`❌ Failed to parse body for 403 retry: ${parseError.message}`)
reject(new Error(`403 retry body parse failed: ${parseError.message}`))
return
}
try {
const retryResult = await this._makeClaudeStreamRequestWithUsageCapture(
retryBody,
accessToken,
proxyAgent,
clientHeaders,
responseStream,
usageCallback,
accountId,
accountType,
sessionHash,
streamTransformer,
{ ...requestOptions, useRandomizedToolNames: true },
isDedicatedOfficialAccount,
onResponseStart,
retryCount
)
resolve(retryResult)
} catch (retryError) {
reject(retryError)
}
return
}
if (this._isOrganizationDisabledError(res.statusCode, errorData)) {
;(async () => {
try {
@@ -1828,7 +2166,7 @@ class ClaudeRelayService {
}
// 如果有 streamTransformer如测试请求使用前端期望的格式
if (streamTransformer) {
if (toolNameStreamTransformer) {
responseStream.write(
`data: ${JSON.stringify({ type: 'error', error: errorMessage })}\n\n`
)
@@ -1867,6 +2205,11 @@ class ClaudeRelayService {
let rateLimitDetected = false // 限流检测标志
// 监听数据块解析SSE并寻找usage信息
// 🧹 内存优化:在闭包创建前提取需要的值,避免闭包捕获 body 和 requestOptions
// body 和 requestOptions 只在闭包外使用,闭包内只引用基本类型
const requestedModel = body?.model || 'unknown'
const { isRealClaudeCodeRequest } = requestOptions
res.on('data', (chunk) => {
try {
const chunkStr = chunk.toString()
@@ -1882,8 +2225,8 @@ class ClaudeRelayService {
if (isStreamWritable(responseStream)) {
const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '')
// 如果有流转换器,应用转换
if (streamTransformer) {
const transformed = streamTransformer(linesToForward)
if (toolNameStreamTransformer) {
const transformed = toolNameStreamTransformer(linesToForward)
if (transformed) {
responseStream.write(transformed)
}
@@ -2016,8 +2359,8 @@ class ClaudeRelayService {
try {
// 处理缓冲区中剩余的数据
if (buffer.trim() && isStreamWritable(responseStream)) {
if (streamTransformer) {
const transformed = streamTransformer(buffer)
if (toolNameStreamTransformer) {
const transformed = toolNameStreamTransformer(buffer)
if (transformed) {
responseStream.write(transformed)
}
@@ -2072,7 +2415,7 @@ class ClaudeRelayService {
// 打印原始的usage数据为JSON字符串避免嵌套问题
logger.info(
`📊 === Stream Request Usage Summary === Model: ${body.model}, Total Events: ${allUsageData.length}, Usage Data: ${JSON.stringify(allUsageData)}`
`📊 === Stream Request Usage Summary === Model: ${requestedModel}, Total Events: ${allUsageData.length}, Usage Data: ${JSON.stringify(allUsageData)}`
)
// 一般一个请求只会使用一个模型即使有多个usage事件也应该合并
@@ -2082,7 +2425,7 @@ class ClaudeRelayService {
output_tokens: totalUsage.output_tokens,
cache_creation_input_tokens: totalUsage.cache_creation_input_tokens,
cache_read_input_tokens: totalUsage.cache_read_input_tokens,
model: allUsageData[allUsageData.length - 1].model || body.model // 使用最后一个模型或请求模型
model: allUsageData[allUsageData.length - 1].model || requestedModel // 使用最后一个模型或请求模型
}
// 如果有详细的cache_creation数据合并它们
@@ -2191,15 +2534,15 @@ class ClaudeRelayService {
}
// 只有真实的 Claude Code 请求才更新 headers流式请求
if (
clientHeaders &&
Object.keys(clientHeaders).length > 0 &&
this.isRealClaudeCodeRequest(body)
) {
if (clientHeaders && Object.keys(clientHeaders).length > 0 && isRealClaudeCodeRequest) {
await claudeCodeHeadersService.storeAccountHeaders(accountId, clientHeaders)
}
}
// 🧹 清理 bodyStore
if (requestOptions.bodyStoreId) {
this.bodyStore.delete(requestOptions.bodyStoreId)
}
logger.debug('🌊 Claude stream response with usage capture completed')
resolve()
})
@@ -2256,6 +2599,10 @@ class ClaudeRelayService {
)
responseStream.end()
}
// 🧹 清理 bodyStore
if (requestOptions.bodyStoreId) {
this.bodyStore.delete(requestOptions.bodyStoreId)
}
reject(error)
})
@@ -2285,6 +2632,10 @@ class ClaudeRelayService {
)
responseStream.end()
}
// 🧹 清理 bodyStore
if (requestOptions.bodyStoreId) {
this.bodyStore.delete(requestOptions.bodyStoreId)
}
reject(new Error('Request timeout'))
})
@@ -2298,6 +2649,8 @@ class ClaudeRelayService {
// 写入请求体
req.write(bodyString)
// 🧹 内存优化:立即清空 bodyString 引用,避免闭包捕获
bodyString = null
req.end()
})
}

View File

@@ -90,7 +90,7 @@ class DroidRelayService {
return normalizedBody
}
async _applyRateLimitTracking(rateLimitInfo, usageSummary, model, context = '') {
async _applyRateLimitTracking(rateLimitInfo, usageSummary, model, context = '', keyId = null) {
if (!rateLimitInfo) {
return
}
@@ -99,7 +99,9 @@ class DroidRelayService {
const { totalTokens, totalCost } = await updateRateLimitCounters(
rateLimitInfo,
usageSummary,
model
model,
keyId,
'droid'
)
if (totalTokens > 0) {
@@ -606,7 +608,8 @@ class DroidRelayService {
clientRequest?.rateLimitInfo,
usageSummary,
model,
' [stream]'
' [stream]',
keyId
)
logger.success(`Droid stream completed - Account: ${account.name}`)
@@ -1225,7 +1228,8 @@ class DroidRelayService {
clientRequest?.rateLimitInfo,
usageSummary,
model,
endpointLabel
endpointLabel,
keyId
)
logger.success(

View File

@@ -14,16 +14,67 @@ const {
} = require('../utils/tokenRefreshLogger')
const tokenRefreshService = require('./tokenRefreshService')
const { createEncryptor } = require('../utils/commonHelper')
const antigravityClient = require('./antigravityClient')
// Gemini 账户键前缀
const GEMINI_ACCOUNT_KEY_PREFIX = 'gemini_account:'
const SHARED_GEMINI_ACCOUNTS_KEY = 'shared_gemini_accounts'
const ACCOUNT_SESSION_MAPPING_PREFIX = 'gemini_session_account_mapping:'
// Gemini CLI OAuth 配置 - 这些是公开的 Gemini CLI 凭据
const OAUTH_CLIENT_ID = '681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com'
const OAUTH_CLIENT_SECRET = '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'
const OAUTH_SCOPES = ['https://www.googleapis.com/auth/cloud-platform']
// Gemini OAuth 配置 - 支持 Gemini CLI 与 Antigravity 两种 OAuth 应用
const OAUTH_PROVIDER_GEMINI_CLI = 'gemini-cli'
const OAUTH_PROVIDER_ANTIGRAVITY = 'antigravity'
const OAUTH_PROVIDERS = {
[OAUTH_PROVIDER_GEMINI_CLI]: {
// Gemini CLI OAuth 配置(公开)
clientId:
process.env.GEMINI_OAUTH_CLIENT_ID ||
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com',
clientSecret: process.env.GEMINI_OAUTH_CLIENT_SECRET || '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',
scopes: ['https://www.googleapis.com/auth/cloud-platform']
},
[OAUTH_PROVIDER_ANTIGRAVITY]: {
// Antigravity OAuth 配置(参考 gcli2api
clientId:
process.env.ANTIGRAVITY_OAUTH_CLIENT_ID ||
'1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com',
clientSecret:
process.env.ANTIGRAVITY_OAUTH_CLIENT_SECRET || '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',
scopes: [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
'https://www.googleapis.com/auth/cclog',
'https://www.googleapis.com/auth/experimentsandconfigs'
]
}
}
if (!process.env.GEMINI_OAUTH_CLIENT_SECRET) {
logger.warn(
'⚠️ GEMINI_OAUTH_CLIENT_SECRET 未设置,使用内置默认值(建议在生产环境通过环境变量覆盖)'
)
}
if (!process.env.ANTIGRAVITY_OAUTH_CLIENT_SECRET) {
logger.warn(
'⚠️ ANTIGRAVITY_OAUTH_CLIENT_SECRET 未设置,使用内置默认值(建议在生产环境通过环境变量覆盖)'
)
}
function normalizeOauthProvider(oauthProvider) {
if (!oauthProvider) {
return OAUTH_PROVIDER_GEMINI_CLI
}
return oauthProvider === OAUTH_PROVIDER_ANTIGRAVITY
? OAUTH_PROVIDER_ANTIGRAVITY
: OAUTH_PROVIDER_GEMINI_CLI
}
function getOauthProviderConfig(oauthProvider) {
const normalized = normalizeOauthProvider(oauthProvider)
return OAUTH_PROVIDERS[normalized] || OAUTH_PROVIDERS[OAUTH_PROVIDER_GEMINI_CLI]
}
// 🌐 TCP Keep-Alive Agent 配置
// 解决长时间流式请求中 NAT/防火墙空闲超时导致的连接中断问题
@@ -41,6 +92,117 @@ logger.info('🌐 Gemini HTTPS Agent initialized with TCP Keep-Alive support')
const encryptor = createEncryptor('gemini-account-salt')
const { encrypt, decrypt } = encryptor
async function fetchAvailableModelsAntigravity(
accessToken,
proxyConfig = null,
refreshToken = null
) {
try {
let effectiveToken = accessToken
if (refreshToken) {
try {
const client = await getOauthClient(
accessToken,
refreshToken,
proxyConfig,
OAUTH_PROVIDER_ANTIGRAVITY
)
if (client && client.getAccessToken) {
const latest = await client.getAccessToken()
if (latest?.token) {
effectiveToken = latest.token
}
}
} catch (error) {
logger.warn('Failed to refresh Antigravity access token for models list:', {
message: error.message
})
}
}
const data = await antigravityClient.fetchAvailableModels({
accessToken: effectiveToken,
proxyConfig
})
const modelsDict = data?.models
const created = Math.floor(Date.now() / 1000)
const models = []
const seen = new Set()
const {
getAntigravityModelAlias,
getAntigravityModelMetadata,
normalizeAntigravityModelInput
} = require('../utils/antigravityModel')
const pushModel = (modelId) => {
if (!modelId || seen.has(modelId)) {
return
}
seen.add(modelId)
const metadata = getAntigravityModelMetadata(modelId)
const entry = {
id: modelId,
object: 'model',
created,
owned_by: 'antigravity'
}
if (metadata?.name) {
entry.name = metadata.name
}
if (metadata?.maxCompletionTokens) {
entry.max_completion_tokens = metadata.maxCompletionTokens
}
if (metadata?.thinking) {
entry.thinking = metadata.thinking
}
models.push(entry)
}
if (modelsDict && typeof modelsDict === 'object') {
for (const modelId of Object.keys(modelsDict)) {
const normalized = normalizeAntigravityModelInput(modelId)
const alias = getAntigravityModelAlias(normalized)
if (!alias) {
continue
}
pushModel(alias)
if (alias.endsWith('-thinking')) {
pushModel(alias.replace(/-thinking$/, ''))
}
if (alias.startsWith('gemini-claude-')) {
pushModel(alias.replace(/^gemini-/, ''))
}
}
}
return models
} catch (error) {
logger.error('Failed to fetch Antigravity models:', error.response?.data || error.message)
return [
{
id: 'gemini-2.5-flash',
object: 'model',
created: Math.floor(Date.now() / 1000),
owned_by: 'antigravity'
}
]
}
}
async function countTokensAntigravity(client, contents, model, proxyConfig = null) {
const { token } = await client.getAccessToken()
const response = await antigravityClient.countTokens({
accessToken: token,
proxyConfig,
contents,
model
})
return response
}
// 🧹 定期清理缓存每10分钟
setInterval(
() => {
@@ -51,14 +213,15 @@ setInterval(
)
// 创建 OAuth2 客户端(支持代理配置)
function createOAuth2Client(redirectUri = null, proxyConfig = null) {
function createOAuth2Client(redirectUri = null, proxyConfig = null, oauthProvider = null) {
// 如果没有提供 redirectUri使用默认值
const uri = redirectUri || 'http://localhost:45462'
const oauthConfig = getOauthProviderConfig(oauthProvider)
// 准备客户端选项
const clientOptions = {
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
clientId: oauthConfig.clientId,
clientSecret: oauthConfig.clientSecret,
redirectUri: uri
}
@@ -79,10 +242,17 @@ function createOAuth2Client(redirectUri = null, proxyConfig = null) {
}
// 生成授权 URL (支持 PKCE 和代理)
async function generateAuthUrl(state = null, redirectUri = null, proxyConfig = null) {
async function generateAuthUrl(
state = null,
redirectUri = null,
proxyConfig = null,
oauthProvider = null
) {
// 使用新的 redirect URI
const finalRedirectUri = redirectUri || 'https://codeassist.google.com/authcode'
const oAuth2Client = createOAuth2Client(finalRedirectUri, proxyConfig)
const normalizedProvider = normalizeOauthProvider(oauthProvider)
const oauthConfig = getOauthProviderConfig(normalizedProvider)
const oAuth2Client = createOAuth2Client(finalRedirectUri, proxyConfig, normalizedProvider)
if (proxyConfig) {
logger.info(
@@ -99,7 +269,7 @@ async function generateAuthUrl(state = null, redirectUri = null, proxyConfig = n
const authUrl = oAuth2Client.generateAuthUrl({
redirect_uri: finalRedirectUri,
access_type: 'offline',
scope: OAUTH_SCOPES,
scope: oauthConfig.scopes,
code_challenge_method: 'S256',
code_challenge: codeVerifier.codeChallenge,
state: stateValue,
@@ -110,7 +280,8 @@ async function generateAuthUrl(state = null, redirectUri = null, proxyConfig = n
authUrl,
state: stateValue,
codeVerifier: codeVerifier.codeVerifier,
redirectUri: finalRedirectUri
redirectUri: finalRedirectUri,
oauthProvider: normalizedProvider
}
}
@@ -171,11 +342,14 @@ async function exchangeCodeForTokens(
code,
redirectUri = null,
codeVerifier = null,
proxyConfig = null
proxyConfig = null,
oauthProvider = null
) {
try {
const normalizedProvider = normalizeOauthProvider(oauthProvider)
const oauthConfig = getOauthProviderConfig(normalizedProvider)
// 创建带代理配置的 OAuth2Client
const oAuth2Client = createOAuth2Client(redirectUri, proxyConfig)
const oAuth2Client = createOAuth2Client(redirectUri, proxyConfig, normalizedProvider)
if (proxyConfig) {
logger.info(
@@ -201,7 +375,7 @@ async function exchangeCodeForTokens(
return {
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
scope: tokens.scope || OAUTH_SCOPES.join(' '),
scope: tokens.scope || oauthConfig.scopes.join(' '),
token_type: tokens.token_type || 'Bearer',
expiry_date: tokens.expiry_date || Date.now() + tokens.expires_in * 1000
}
@@ -212,9 +386,11 @@ async function exchangeCodeForTokens(
}
// 刷新访问令牌
async function refreshAccessToken(refreshToken, proxyConfig = null) {
async function refreshAccessToken(refreshToken, proxyConfig = null, oauthProvider = null) {
const normalizedProvider = normalizeOauthProvider(oauthProvider)
const oauthConfig = getOauthProviderConfig(normalizedProvider)
// 创建带代理配置的 OAuth2Client
const oAuth2Client = createOAuth2Client(null, proxyConfig)
const oAuth2Client = createOAuth2Client(null, proxyConfig, normalizedProvider)
try {
// 设置 refresh_token
@@ -246,7 +422,7 @@ async function refreshAccessToken(refreshToken, proxyConfig = null) {
return {
access_token: credentials.access_token,
refresh_token: credentials.refresh_token || refreshToken, // 保留原 refresh_token 如果没有返回新的
scope: credentials.scope || OAUTH_SCOPES.join(' '),
scope: credentials.scope || oauthConfig.scopes.join(' '),
token_type: credentials.token_type || 'Bearer',
expiry_date: credentials.expiry_date || Date.now() + 3600000 // 默认1小时过期
}
@@ -266,6 +442,8 @@ async function refreshAccessToken(refreshToken, proxyConfig = null) {
async function createAccount(accountData) {
const id = uuidv4()
const now = new Date().toISOString()
const oauthProvider = normalizeOauthProvider(accountData.oauthProvider)
const oauthConfig = getOauthProviderConfig(oauthProvider)
// 处理凭证数据
let geminiOauth = null
@@ -298,7 +476,7 @@ async function createAccount(accountData) {
geminiOauth = JSON.stringify({
access_token: accessToken,
refresh_token: refreshToken,
scope: accountData.scope || OAUTH_SCOPES.join(' '),
scope: accountData.scope || oauthConfig.scopes.join(' '),
token_type: accountData.tokenType || 'Bearer',
expiry_date: accountData.expiryDate || Date.now() + 3600000 // 默认1小时
})
@@ -326,7 +504,8 @@ async function createAccount(accountData) {
refreshToken: refreshToken ? encrypt(refreshToken) : '',
expiresAt, // OAuth Token 过期时间(技术字段,自动刷新)
// 只有OAuth方式才有scopes手动添加的没有
scopes: accountData.geminiOauth ? accountData.scopes || OAUTH_SCOPES.join(' ') : '',
scopes: accountData.geminiOauth ? accountData.scopes || oauthConfig.scopes.join(' ') : '',
oauthProvider,
// ✅ 新增:账户订阅到期时间(业务字段,手动管理)
subscriptionExpiresAt: accountData.subscriptionExpiresAt || null,
@@ -436,6 +615,10 @@ async function updateAccount(accountId, updates) {
updates.schedulable = updates.schedulable.toString()
}
if (updates.oauthProvider !== undefined) {
updates.oauthProvider = normalizeOauthProvider(updates.oauthProvider)
}
// 加密敏感字段
if (updates.geminiOauth) {
updates.geminiOauth = encrypt(
@@ -824,12 +1007,13 @@ async function refreshAccountToken(accountId) {
// 重新获取账户数据(可能已被其他进程刷新)
const updatedAccount = await getAccount(accountId)
if (updatedAccount && updatedAccount.accessToken) {
const oauthConfig = getOauthProviderConfig(updatedAccount.oauthProvider)
const accessToken = decrypt(updatedAccount.accessToken)
return {
access_token: accessToken,
refresh_token: updatedAccount.refreshToken ? decrypt(updatedAccount.refreshToken) : '',
expiry_date: updatedAccount.expiresAt ? new Date(updatedAccount.expiresAt).getTime() : 0,
scope: updatedAccount.scope || OAUTH_SCOPES.join(' '),
scope: updatedAccount.scopes || oauthConfig.scopes.join(' '),
token_type: 'Bearer'
}
}
@@ -843,7 +1027,11 @@ async function refreshAccountToken(accountId) {
// account.refreshToken 已经是解密后的值(从 getAccount 返回)
// 传入账户的代理配置
const newTokens = await refreshAccessToken(account.refreshToken, account.proxy)
const newTokens = await refreshAccessToken(
account.refreshToken,
account.proxy,
account.oauthProvider
)
// 更新账户信息
const updates = {
@@ -975,14 +1163,15 @@ async function getAccountRateLimitInfo(accountId) {
}
// 获取配置的OAuth客户端 - 参考GeminiCliSimulator的getOauthClient方法支持代理
async function getOauthClient(accessToken, refreshToken, proxyConfig = null) {
const client = createOAuth2Client(null, proxyConfig)
async function getOauthClient(accessToken, refreshToken, proxyConfig = null, oauthProvider = null) {
const normalizedProvider = normalizeOauthProvider(oauthProvider)
const oauthConfig = getOauthProviderConfig(normalizedProvider)
const client = createOAuth2Client(null, proxyConfig, normalizedProvider)
const creds = {
access_token: accessToken,
refresh_token: refreshToken,
scope:
'https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.profile openid https://www.googleapis.com/auth/userinfo.email',
scope: oauthConfig.scopes.join(' '),
token_type: 'Bearer',
expiry_date: 1754269905646
}
@@ -1448,6 +1637,43 @@ async function generateContent(
return response.data
}
// 调用 Antigravity 上游生成内容(非流式)
async function generateContentAntigravity(
client,
requestData,
userPromptId,
projectId = null,
sessionId = null,
proxyConfig = null
) {
const { token } = await client.getAccessToken()
const { model } = antigravityClient.buildAntigravityEnvelope({
requestData,
projectId,
sessionId,
userPromptId
})
logger.info('🪐 Antigravity generateContent API调用开始', {
model,
userPromptId,
projectId,
sessionId
})
const { response } = await antigravityClient.request({
accessToken: token,
proxyConfig,
requestData,
projectId,
sessionId,
userPromptId,
stream: false
})
logger.info('✅ Antigravity generateContent API调用成功')
return response.data
}
// 调用 Code Assist API 生成内容(流式)
async function generateContentStream(
client,
@@ -1532,6 +1758,46 @@ async function generateContentStream(
return response.data // 返回流对象
}
// 调用 Antigravity 上游生成内容(流式)
async function generateContentStreamAntigravity(
client,
requestData,
userPromptId,
projectId = null,
sessionId = null,
signal = null,
proxyConfig = null
) {
const { token } = await client.getAccessToken()
const { model } = antigravityClient.buildAntigravityEnvelope({
requestData,
projectId,
sessionId,
userPromptId
})
logger.info('🌊 Antigravity streamGenerateContent API调用开始', {
model,
userPromptId,
projectId,
sessionId
})
const { response } = await antigravityClient.request({
accessToken: token,
proxyConfig,
requestData,
projectId,
sessionId,
userPromptId,
stream: true,
signal,
params: { alt: 'sse' }
})
logger.info('✅ Antigravity streamGenerateContent API调用成功开始流式传输')
return response.data
}
// 更新账户的临时项目 ID
async function updateTempProjectId(accountId, tempProjectId) {
if (!tempProjectId) {
@@ -1625,10 +1891,12 @@ module.exports = {
decrypt,
encryptor, // 暴露加密器以便测试和监控
countTokens,
countTokensAntigravity,
generateContent,
generateContentStream,
generateContentAntigravity,
generateContentStreamAntigravity,
fetchAvailableModelsAntigravity,
updateTempProjectId,
resetAccountStatus,
OAUTH_CLIENT_ID,
OAUTH_SCOPES
resetAccountStatus
}

View File

@@ -163,7 +163,8 @@ async function* handleStreamResponse(response, model, apiKeyId, accountId = null
0, // cacheCreateTokens (Gemini 没有这个概念)
0, // cacheReadTokens (Gemini 没有这个概念)
model,
accountId
accountId,
'gemini'
)
.catch((error) => {
logger.error('❌ Failed to record Gemini usage:', error)
@@ -317,7 +318,8 @@ async function sendGeminiRequest({
0, // cacheCreateTokens
0, // cacheReadTokens
model,
accountId
accountId,
'gemini'
)
.catch((error) => {
logger.error('❌ Failed to record Gemini usage:', error)

View File

@@ -557,7 +557,8 @@ class OpenAIResponsesRelayService {
cacheCreateTokens,
cacheReadTokens,
modelToRecord,
account.id
account.id,
'openai-responses'
)
logger.info(
@@ -685,7 +686,8 @@ class OpenAIResponsesRelayService {
cacheCreateTokens,
cacheReadTokens,
actualModel,
account.id
account.id,
'openai-responses'
)
logger.info(

View File

@@ -12,6 +12,51 @@ class QuotaCardService {
this.CARD_PREFIX = 'quota_card:'
this.REDEMPTION_PREFIX = 'redemption:'
this.CARD_CODE_PREFIX = 'CC' // 卡号前缀
this.LIMITS_CONFIG_KEY = 'system:quota_card_limits'
}
/**
* 获取额度卡上限配置
*/
async getLimitsConfig() {
try {
const configStr = await redis.client.get(this.LIMITS_CONFIG_KEY)
if (configStr) {
return JSON.parse(configStr)
}
// 没有 Redis 配置时,使用 config.js 默认值
const config = require('../../config/config')
return config.quotaCardLimits || {
enabled: true,
maxExpiryDays: 90,
maxTotalCostLimit: 1000
}
} catch (error) {
logger.error('❌ Failed to get limits config:', error)
return { enabled: true, maxExpiryDays: 90, maxTotalCostLimit: 1000 }
}
}
/**
* 保存额度卡上限配置
*/
async saveLimitsConfig(config) {
try {
const parsedDays = parseInt(config.maxExpiryDays)
const parsedCost = parseFloat(config.maxTotalCostLimit)
const newConfig = {
enabled: config.enabled !== false,
maxExpiryDays: Number.isNaN(parsedDays) ? 90 : parsedDays,
maxTotalCostLimit: Number.isNaN(parsedCost) ? 1000 : parsedCost,
updatedAt: new Date().toISOString()
}
await redis.client.set(this.LIMITS_CONFIG_KEY, JSON.stringify(newConfig))
logger.info('✅ Quota card limits config saved')
return newConfig
} catch (error) {
logger.error('❌ Failed to save limits config:', error)
throw error
}
}
/**
@@ -248,53 +293,119 @@ class QuotaCardService {
// 获取卡信息
const card = await this.getCardByCode(code)
if (!card) {
throw new Error('Card not found')
throw new Error('卡号不存在')
}
// 检查卡状态
if (card.status !== 'unused') {
throw new Error(`Card is ${card.status}, cannot redeem`)
const statusMap = { used: '已使用', expired: '已过期', revoked: '已撤销' }
throw new Error(`卡片${statusMap[card.status] || card.status},无法兑换`)
}
// 检查卡是否过期
if (card.expiresAt && new Date(card.expiresAt) < new Date()) {
// 更新卡状态为过期
await this._updateCardStatus(card.id, 'expired')
throw new Error('Card has expired')
throw new Error('卡片已过期')
}
// 获取 API Key 信息
const apiKeyService = require('./apiKeyService')
const keyData = await redis.getApiKey(apiKeyId)
if (!keyData || Object.keys(keyData).length === 0) {
throw new Error('API key not found')
throw new Error('API Key 不存在')
}
// 检查 API Key 是否为聚合类型(只有聚合 Key 才能核销额度卡)
if (card.type !== 'time' && keyData.isAggregated !== 'true') {
throw new Error('Only aggregated keys can redeem quota cards')
}
// 获取上限配置
const limits = await this.getLimitsConfig()
// 执行核销
const redemptionId = uuidv4()
const now = new Date().toISOString()
// 记录核销前状态
const beforeQuota = parseFloat(keyData.quotaLimit || 0)
const beforeLimit = parseFloat(keyData.totalCostLimit || 0)
const beforeExpiry = keyData.expiresAt || ''
// 应用卡效果
let afterQuota = beforeQuota
let afterLimit = beforeLimit
let afterExpiry = beforeExpiry
let quotaAdded = 0
let timeAdded = 0
let actualTimeUnit = card.timeUnit // 实际使用的时间单位(截断时会改为 days
const warnings = [] // 截断警告信息
if (card.type === 'quota' || card.type === 'combo') {
const result = await apiKeyService.addQuota(apiKeyId, card.quotaAmount)
afterQuota = result.newQuotaLimit
let amountToAdd = card.quotaAmount
// 上限保护:检查是否超过最大额度限制
if (limits.enabled && limits.maxTotalCostLimit > 0) {
const maxAllowed = limits.maxTotalCostLimit - beforeLimit
if (amountToAdd > maxAllowed) {
amountToAdd = Math.max(0, maxAllowed)
warnings.push(`额度已达上限,本次仅增加 ${amountToAdd} CC原卡面 ${card.quotaAmount} CC`)
logger.warn(`额度卡兑换超出上限,已截断:原 ${card.quotaAmount} -> 实际 ${amountToAdd}`)
}
}
if (amountToAdd > 0) {
const result = await apiKeyService.addTotalCostLimit(apiKeyId, amountToAdd)
afterLimit = result.newTotalCostLimit
quotaAdded = amountToAdd
}
}
if (card.type === 'time' || card.type === 'combo') {
// 计算新的过期时间
let baseDate = beforeExpiry ? new Date(beforeExpiry) : new Date()
if (baseDate < new Date()) {
baseDate = new Date()
}
let newExpiry = new Date(baseDate)
switch (card.timeUnit) {
case 'hours':
newExpiry.setTime(newExpiry.getTime() + card.timeAmount * 60 * 60 * 1000)
break
case 'days':
newExpiry.setDate(newExpiry.getDate() + card.timeAmount)
break
case 'months':
newExpiry.setMonth(newExpiry.getMonth() + card.timeAmount)
break
}
// 上限保护:检查是否超过最大有效期
if (limits.enabled && limits.maxExpiryDays > 0) {
const maxExpiry = new Date()
maxExpiry.setDate(maxExpiry.getDate() + limits.maxExpiryDays)
if (newExpiry > maxExpiry) {
newExpiry = maxExpiry
warnings.push(`有效期已达上限(${limits.maxExpiryDays}天),时间已截断`)
logger.warn(`时间卡兑换超出上限,已截断至 ${maxExpiry.toISOString()}`)
}
}
const result = await apiKeyService.extendExpiry(apiKeyId, card.timeAmount, card.timeUnit)
afterExpiry = result.newExpiresAt
// 如果有上限保护,使用截断后的时间
if (limits.enabled && limits.maxExpiryDays > 0) {
const maxExpiry = new Date()
maxExpiry.setDate(maxExpiry.getDate() + limits.maxExpiryDays)
if (new Date(result.newExpiresAt) > maxExpiry) {
await redis.client.hset(`apikey:${apiKeyId}`, 'expiresAt', maxExpiry.toISOString())
afterExpiry = maxExpiry.toISOString()
// 计算实际增加的天数,截断时统一用天
const actualDays = Math.max(0, Math.ceil((maxExpiry - baseDate) / (1000 * 60 * 60 * 24)))
timeAdded = actualDays
actualTimeUnit = 'days'
} else {
afterExpiry = result.newExpiresAt
timeAdded = card.timeAmount
}
} else {
afterExpiry = result.newExpiresAt
timeAdded = card.timeAmount
}
}
// 更新卡状态
@@ -321,11 +432,11 @@ class QuotaCardService {
username,
apiKeyId,
apiKeyName: keyData.name || '',
quotaAdded: String(card.type === 'time' ? 0 : card.quotaAmount),
timeAdded: String(card.type === 'quota' ? 0 : card.timeAmount),
timeUnit: card.timeUnit,
beforeQuota: String(beforeQuota),
afterQuota: String(afterQuota),
quotaAdded: String(quotaAdded),
timeAdded: String(timeAdded),
timeUnit: actualTimeUnit,
beforeLimit: String(beforeLimit),
afterLimit: String(afterLimit),
beforeExpiry,
afterExpiry,
timestamp: now,
@@ -343,14 +454,15 @@ class QuotaCardService {
return {
success: true,
warnings,
redemptionId,
cardCode: card.code,
cardType: card.type,
quotaAdded: card.type === 'time' ? 0 : card.quotaAmount,
timeAdded: card.type === 'quota' ? 0 : card.timeAmount,
timeUnit: card.timeUnit,
beforeQuota,
afterQuota,
quotaAdded,
timeAdded,
timeUnit: actualTimeUnit,
beforeLimit,
afterLimit,
beforeExpiry,
afterExpiry
}
@@ -383,13 +495,13 @@ class QuotaCardService {
const now = new Date().toISOString()
// 撤销效果
let actualQuotaDeducted = 0
let actualDeducted = 0
if (parseFloat(redemptionData.quotaAdded) > 0) {
const result = await apiKeyService.deductQuotaLimit(
const result = await apiKeyService.deductTotalCostLimit(
redemptionData.apiKeyId,
parseFloat(redemptionData.quotaAdded)
)
actualQuotaDeducted = result.actualDeducted
actualDeducted = result.actualDeducted
}
// 注意:时间卡撤销比较复杂,这里简化处理,不回退时间
@@ -401,7 +513,7 @@ class QuotaCardService {
revokedAt: now,
revokedBy,
revokeReason: reason,
actualQuotaDeducted: String(actualQuotaDeducted)
actualDeducted: String(actualDeducted)
})
// 更新卡状态
@@ -423,7 +535,7 @@ class QuotaCardService {
success: true,
redemptionId,
cardCode: redemptionData.cardCode,
actualQuotaDeducted,
actualDeducted,
reason
}
} catch (error) {
@@ -469,8 +581,8 @@ class QuotaCardService {
quotaAdded: parseFloat(data.quotaAdded || 0),
timeAdded: parseInt(data.timeAdded || 0),
timeUnit: data.timeUnit,
beforeQuota: parseFloat(data.beforeQuota || 0),
afterQuota: parseFloat(data.afterQuota || 0),
beforeLimit: parseFloat(data.beforeLimit || 0),
afterLimit: parseFloat(data.afterLimit || 0),
beforeExpiry: data.beforeExpiry,
afterExpiry: data.afterExpiry,
timestamp: data.timestamp,
@@ -478,7 +590,7 @@ class QuotaCardService {
revokedAt: data.revokedAt,
revokedBy: data.revokedBy,
revokeReason: data.revokeReason,
actualQuotaDeducted: parseFloat(data.actualQuotaDeducted || 0)
actualDeducted: parseFloat(data.actualDeducted || 0)
})
}
}

View File

@@ -72,7 +72,8 @@ class RateLimitCleanupService {
const results = {
openai: { checked: 0, cleared: 0, errors: [] },
claude: { checked: 0, cleared: 0, errors: [] },
claudeConsole: { checked: 0, cleared: 0, errors: [] }
claudeConsole: { checked: 0, cleared: 0, errors: [] },
tokenRefresh: { checked: 0, refreshed: 0, errors: [] }
}
// 清理 OpenAI 账号
@@ -84,21 +85,29 @@ class RateLimitCleanupService {
// 清理 Claude Console 账号
await this.cleanupClaudeConsoleAccounts(results.claudeConsole)
// 主动刷新等待重置的 Claude 账户 Token防止 5小时/7天 等待期间 Token 过期)
await this.proactiveRefreshClaudeTokens(results.tokenRefresh)
const totalChecked =
results.openai.checked + results.claude.checked + results.claudeConsole.checked
const totalCleared =
results.openai.cleared + results.claude.cleared + results.claudeConsole.cleared
const duration = Date.now() - startTime
if (totalCleared > 0) {
if (totalCleared > 0 || results.tokenRefresh.refreshed > 0) {
logger.info(
`✅ Rate limit cleanup completed: ${totalCleared} accounts cleared out of ${totalChecked} checked (${duration}ms)`
`✅ Rate limit cleanup completed: ${totalCleared}/${totalChecked} accounts cleared, ${results.tokenRefresh.refreshed} tokens refreshed (${duration}ms)`
)
logger.info(` OpenAI: ${results.openai.cleared}/${results.openai.checked}`)
logger.info(` Claude: ${results.claude.cleared}/${results.claude.checked}`)
logger.info(
` Claude Console: ${results.claudeConsole.cleared}/${results.claudeConsole.checked}`
)
if (results.tokenRefresh.checked > 0 || results.tokenRefresh.refreshed > 0) {
logger.info(
` Token Refresh: ${results.tokenRefresh.refreshed}/${results.tokenRefresh.checked} refreshed`
)
}
// 发送 webhook 恢复通知
if (this.clearedAccounts.length > 0) {
@@ -114,7 +123,8 @@ class RateLimitCleanupService {
const allErrors = [
...results.openai.errors,
...results.claude.errors,
...results.claudeConsole.errors
...results.claudeConsole.errors,
...results.tokenRefresh.errors
]
if (allErrors.length > 0) {
logger.warn(`⚠️ Encountered ${allErrors.length} errors during cleanup:`, allErrors)
@@ -348,6 +358,75 @@ class RateLimitCleanupService {
}
}
/**
* 主动刷新 Claude 账户 Token防止等待重置期间 Token 过期)
* 仅对因限流/配额限制而等待重置的账户执行刷新:
* - 429 限流账户rateLimitAutoStopped=true
* - 5小时限制自动停止账户fiveHourAutoStopped=true
* 不处理错误状态账户error/temp_error
*/
async proactiveRefreshClaudeTokens(result) {
try {
const redis = require('../models/redis')
const accounts = await redis.getAllClaudeAccounts()
const now = Date.now()
const refreshAheadMs = 30 * 60 * 1000 // 提前30分钟刷新
const recentRefreshMs = 5 * 60 * 1000 // 5分钟内刷新过则跳过
for (const account of accounts) {
// 1. 必须激活
if (account.isActive !== 'true') {
continue
}
// 2. 必须有 refreshToken
if (!account.refreshToken) {
continue
}
// 3. 【优化】仅处理因限流/配额限制而等待重置的账户
// 正常调度的账户会在请求时自动刷新,无需主动刷新
// 错误状态账户的 Token 可能已失效,刷新也会失败
const isWaitingForReset =
account.rateLimitAutoStopped === 'true' || // 429 限流
account.fiveHourAutoStopped === 'true' // 5小时限制自动停止
if (!isWaitingForReset) {
continue
}
// 4. 【优化】如果最近 5 分钟内已刷新,跳过(避免重复刷新)
const lastRefreshAt = account.lastRefreshAt ? new Date(account.lastRefreshAt).getTime() : 0
if (now - lastRefreshAt < recentRefreshMs) {
continue
}
// 5. 检查 Token 是否即将过期30分钟内
const expiresAt = parseInt(account.expiresAt)
if (expiresAt && now < expiresAt - refreshAheadMs) {
continue
}
// 符合条件,执行刷新
result.checked++
try {
await claudeAccountService.refreshAccountToken(account.id)
result.refreshed++
logger.info(`🔄 Proactively refreshed token: ${account.name} (${account.id})`)
} catch (error) {
result.errors.push({
accountId: account.id,
accountName: account.name,
error: error.message
})
logger.warn(`⚠️ Proactive refresh failed for ${account.name}: ${error.message}`)
}
}
} catch (error) {
logger.error('Failed to proactively refresh Claude tokens:', error)
result.errors.push({ error: error.message })
}
}
/**
* 手动触发一次清理(供 API 或 CLI 调用)
*/

View File

@@ -207,6 +207,36 @@ class ServiceRatesService {
return 'claude'
}
/**
* 根据账户类型获取服务类型(优先级高于模型推断)
*/
getServiceFromAccountType(accountType) {
if (!accountType) return null
const mapping = {
claude: 'claude',
'claude-official': 'claude',
'claude-console': 'claude',
ccr: 'ccr',
bedrock: 'bedrock',
gemini: 'gemini',
'openai-responses': 'codex',
openai: 'codex',
azure: 'azure',
'azure-openai': 'azure',
droid: 'droid'
}
return mapping[accountType] || null
}
/**
* 获取服务类型(优先 accountType后备 model
*/
getService(accountType, model) {
return this.getServiceFromAccountType(accountType) || this.getServiceFromModel(model)
}
/**
* 获取所有支持的服务列表
*/

View File

@@ -5,11 +5,51 @@ const redis = require('../models/redis')
const logger = require('../utils/logger')
const { isSchedulable, isActive, sortAccountsByPriority } = require('../utils/commonHelper')
const OAUTH_PROVIDER_GEMINI_CLI = 'gemini-cli'
const OAUTH_PROVIDER_ANTIGRAVITY = 'antigravity'
const KNOWN_OAUTH_PROVIDERS = [OAUTH_PROVIDER_GEMINI_CLI, OAUTH_PROVIDER_ANTIGRAVITY]
function normalizeOauthProvider(oauthProvider) {
if (!oauthProvider) {
return OAUTH_PROVIDER_GEMINI_CLI
}
return oauthProvider === OAUTH_PROVIDER_ANTIGRAVITY
? OAUTH_PROVIDER_ANTIGRAVITY
: OAUTH_PROVIDER_GEMINI_CLI
}
class UnifiedGeminiScheduler {
constructor() {
this.SESSION_MAPPING_PREFIX = 'unified_gemini_session_mapping:'
}
_getSessionMappingKey(sessionHash, oauthProvider = null) {
if (!sessionHash) {
return null
}
if (!oauthProvider) {
return `${this.SESSION_MAPPING_PREFIX}${sessionHash}`
}
const normalized = normalizeOauthProvider(oauthProvider)
return `${this.SESSION_MAPPING_PREFIX}${normalized}:${sessionHash}`
}
// 🔧 辅助方法:检查账户是否可调度(兼容字符串和布尔值)
_isSchedulable(schedulable) {
// 如果是 undefined 或 null默认为可调度
if (schedulable === undefined || schedulable === null) {
return true
}
// 明确设置为 false布尔值或 'false'(字符串)时不可调度
return schedulable !== false && schedulable !== 'false'
}
// 🔧 辅助方法:检查账户是否激活(兼容字符串和布尔值)
_isActive(isActive) {
// 兼容布尔值 true 和字符串 'true'
return isActive === true || isActive === 'true'
}
// 🎯 统一调度Gemini账号
async selectAccountForApiKey(
apiKeyData,
@@ -17,7 +57,8 @@ class UnifiedGeminiScheduler {
requestedModel = null,
options = {}
) {
const { allowApiAccounts = false } = options
const { allowApiAccounts = false, oauthProvider = null } = options
const normalizedOauthProvider = oauthProvider ? normalizeOauthProvider(oauthProvider) : null
try {
// 如果API Key绑定了专属账户或分组优先使用
@@ -59,15 +100,28 @@ class UnifiedGeminiScheduler {
// 普通 Gemini OAuth 专属账户
else {
const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId)
if (boundAccount && isActive(boundAccount.isActive) && boundAccount.status !== 'error') {
logger.info(
`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId}) for API key ${apiKeyData.name}`
)
// 更新账户的最后使用时间
await geminiAccountService.markAccountUsed(apiKeyData.geminiAccountId)
return {
accountId: apiKeyData.geminiAccountId,
accountType: 'gemini'
if (
boundAccount &&
this._isActive(boundAccount.isActive) &&
boundAccount.status !== 'error'
) {
if (
normalizedOauthProvider &&
normalizeOauthProvider(boundAccount.oauthProvider) !== normalizedOauthProvider
) {
logger.warn(
`⚠️ Bound Gemini OAuth account ${boundAccount.name} oauthProvider=${normalizeOauthProvider(boundAccount.oauthProvider)} does not match requested oauthProvider=${normalizedOauthProvider}, falling back to pool`
)
} else {
logger.info(
`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId}) for API key ${apiKeyData.name}`
)
// 更新账户的最后使用时间
await geminiAccountService.markAccountUsed(apiKeyData.geminiAccountId)
return {
accountId: apiKeyData.geminiAccountId,
accountType: 'gemini'
}
}
} else {
logger.warn(
@@ -79,7 +133,7 @@ class UnifiedGeminiScheduler {
// 如果有会话哈希,检查是否有已映射的账户
if (sessionHash) {
const mappedAccount = await this._getSessionMapping(sessionHash)
const mappedAccount = await this._getSessionMapping(sessionHash, normalizedOauthProvider)
if (mappedAccount) {
// 验证映射的账户是否仍然可用
const isAvailable = await this._isAccountAvailable(
@@ -88,7 +142,7 @@ class UnifiedGeminiScheduler {
)
if (isAvailable) {
// 🚀 智能会话续期(续期 unified 映射键,按配置)
await this._extendSessionMappingTTL(sessionHash)
await this._extendSessionMappingTTL(sessionHash, normalizedOauthProvider)
logger.info(
`🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`
)
@@ -109,11 +163,10 @@ class UnifiedGeminiScheduler {
}
// 获取所有可用账户
const availableAccounts = await this._getAllAvailableAccounts(
apiKeyData,
requestedModel,
allowApiAccounts
)
const availableAccounts = await this._getAllAvailableAccounts(apiKeyData, requestedModel, {
allowApiAccounts,
oauthProvider: normalizedOauthProvider
})
if (availableAccounts.length === 0) {
// 提供更详细的错误信息
@@ -137,7 +190,8 @@ class UnifiedGeminiScheduler {
await this._setSessionMapping(
sessionHash,
selectedAccount.accountId,
selectedAccount.accountType
selectedAccount.accountType,
normalizedOauthProvider
)
logger.info(
`🎯 Created new sticky session mapping: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`
@@ -166,7 +220,18 @@ class UnifiedGeminiScheduler {
}
// 📋 获取所有可用账户
async _getAllAvailableAccounts(apiKeyData, requestedModel = null, allowApiAccounts = false) {
async _getAllAvailableAccounts(
apiKeyData,
requestedModel = null,
allowApiAccountsOrOptions = false
) {
const options =
allowApiAccountsOrOptions && typeof allowApiAccountsOrOptions === 'object'
? allowApiAccountsOrOptions
: { allowApiAccounts: allowApiAccountsOrOptions }
const { allowApiAccounts = false, oauthProvider = null } = options
const normalizedOauthProvider = oauthProvider ? normalizeOauthProvider(oauthProvider) : null
const availableAccounts = []
// 如果API Key绑定了专属账户优先返回
@@ -222,7 +287,17 @@ class UnifiedGeminiScheduler {
// 普通 Gemini OAuth 账户
else if (!apiKeyData.geminiAccountId.startsWith('group:')) {
const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId)
if (boundAccount && isActive(boundAccount.isActive) && boundAccount.status !== 'error') {
if (
boundAccount &&
this._isActive(boundAccount.isActive) &&
boundAccount.status !== 'error'
) {
if (
normalizedOauthProvider &&
normalizeOauthProvider(boundAccount.oauthProvider) !== normalizedOauthProvider
) {
return availableAccounts
}
const isRateLimited = await this.isAccountRateLimited(boundAccount.id)
if (!isRateLimited) {
// 检查模型支持
@@ -272,6 +347,12 @@ class UnifiedGeminiScheduler {
(account.accountType === 'shared' || !account.accountType) && // 兼容旧数据
isSchedulable(account.schedulable)
) {
if (
normalizedOauthProvider &&
normalizeOauthProvider(account.oauthProvider) !== normalizedOauthProvider
) {
continue
}
// 检查是否可调度
// 检查token是否过期
@@ -391,9 +472,10 @@ class UnifiedGeminiScheduler {
}
// 🔗 获取会话映射
async _getSessionMapping(sessionHash) {
async _getSessionMapping(sessionHash, oauthProvider = null) {
const client = redis.getClientSafe()
const mappingData = await client.get(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`)
const key = this._getSessionMappingKey(sessionHash, oauthProvider)
const mappingData = key ? await client.get(key) : null
if (mappingData) {
try {
@@ -408,27 +490,42 @@ class UnifiedGeminiScheduler {
}
// 💾 设置会话映射
async _setSessionMapping(sessionHash, accountId, accountType) {
async _setSessionMapping(sessionHash, accountId, accountType, oauthProvider = null) {
const client = redis.getClientSafe()
const mappingData = JSON.stringify({ accountId, accountType })
// 依据配置设置TTL小时
const appConfig = require('../../config/config')
const ttlHours = appConfig.session?.stickyTtlHours || 1
const ttlSeconds = Math.max(1, Math.floor(ttlHours * 60 * 60))
await client.setex(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`, ttlSeconds, mappingData)
const key = this._getSessionMappingKey(sessionHash, oauthProvider)
if (!key) {
return
}
await client.setex(key, ttlSeconds, mappingData)
}
// 🗑️ 删除会话映射
async _deleteSessionMapping(sessionHash) {
const client = redis.getClientSafe()
await client.del(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`)
if (!sessionHash) {
return
}
const keys = [this._getSessionMappingKey(sessionHash)]
for (const provider of KNOWN_OAUTH_PROVIDERS) {
keys.push(this._getSessionMappingKey(sessionHash, provider))
}
await client.del(keys.filter(Boolean))
}
// 🔁 续期统一调度会话映射TTL针对 unified_gemini_session_mapping:* 键),遵循会话配置
async _extendSessionMappingTTL(sessionHash) {
async _extendSessionMappingTTL(sessionHash, oauthProvider = null) {
try {
const client = redis.getClientSafe()
const key = `${this.SESSION_MAPPING_PREFIX}${sessionHash}`
const key = this._getSessionMappingKey(sessionHash, oauthProvider)
if (!key) {
return false
}
const remainingTTL = await client.ttl(key)
if (remainingTTL === -2) {