fix: droid增加comm端点

This commit is contained in:
shaw
2025-11-27 20:38:50 +08:00
parent 89238818eb
commit 4aeb47062b
4 changed files with 197 additions and 17 deletions

View File

@@ -18,11 +18,12 @@ const RUNTIME_EVENT_FMT_PAYLOAD = 'fmtPayload'
class DroidRelayService {
constructor() {
this.factoryApiBaseUrl = 'https://app.factory.ai/api/llm'
this.factoryApiBaseUrl = 'https://api.factory.ai/api/llm'
this.endpoints = {
anthropic: '/a/v1/messages',
openai: '/o/v1/responses'
openai: '/o/v1/responses',
comm: '/o/v1/chat/completions'
}
this.userAgent = 'factory-cli/0.19.12'
@@ -36,10 +37,14 @@ class DroidRelayService {
}
const normalized = String(endpointType).toLowerCase()
if (normalized === 'openai' || normalized === 'common') {
if (normalized === 'openai') {
return 'openai'
}
if (normalized === 'comm') {
return 'comm'
}
if (normalized === 'anthropic') {
return 'anthropic'
}
@@ -559,8 +564,8 @@ class DroidRelayService {
if (endpointType === 'anthropic') {
// Anthropic Messages API 格式
this._parseAnthropicUsageFromSSE(chunkStr, buffer, currentUsageData)
} else if (endpointType === 'openai') {
// OpenAI Chat Completions 格式
} else if (endpointType === 'openai' || endpointType === 'comm') {
// OpenAI Chat Completions 格式(openai 和 comm 共用)
this._parseOpenAIUsageFromSSE(chunkStr, buffer, currentUsageData)
}
@@ -716,8 +721,21 @@ class DroidRelayService {
// 兼容传统 Chat Completions usage 字段
if (data.usage) {
currentUsageData.input_tokens = data.usage.prompt_tokens || 0
currentUsageData.output_tokens = data.usage.completion_tokens || 0
currentUsageData.total_tokens = data.usage.total_tokens || 0
// completion_tokens 可能缺失(如某些模型响应),从 total_tokens - prompt_tokens 计算
if (
data.usage.completion_tokens !== undefined &&
data.usage.completion_tokens !== null
) {
currentUsageData.output_tokens = data.usage.completion_tokens
} else if (currentUsageData.total_tokens > 0 && currentUsageData.input_tokens >= 0) {
currentUsageData.output_tokens = Math.max(
0,
currentUsageData.total_tokens - currentUsageData.input_tokens
)
} else {
currentUsageData.output_tokens = 0
}
logger.debug('📊 Droid OpenAI usage:', currentUsageData)
}
@@ -727,8 +745,18 @@ class DroidRelayService {
const { usage } = data.response
currentUsageData.input_tokens =
usage.input_tokens || usage.prompt_tokens || usage.total_tokens || 0
currentUsageData.output_tokens = usage.output_tokens || usage.completion_tokens || 0
currentUsageData.total_tokens = usage.total_tokens || 0
// completion_tokens/output_tokens 可能缺失,从 total_tokens - input_tokens 计算
if (usage.output_tokens !== undefined || usage.completion_tokens !== undefined) {
currentUsageData.output_tokens = usage.output_tokens || usage.completion_tokens || 0
} else if (currentUsageData.total_tokens > 0 && currentUsageData.input_tokens >= 0) {
currentUsageData.output_tokens = Math.max(
0,
currentUsageData.total_tokens - currentUsageData.input_tokens
)
} else {
currentUsageData.output_tokens = 0
}
logger.debug('📊 Droid OpenAI response usage:', currentUsageData)
}
@@ -763,7 +791,7 @@ class DroidRelayService {
return false
}
if (endpointType === 'openai') {
if (endpointType === 'openai' || endpointType === 'comm') {
if (lower.includes('data: [done]')) {
return true
}
@@ -817,9 +845,16 @@ class DroidRelayService {
usageData.inputTokens ??
usageData.total_input_tokens
)
const outputTokens = toNumber(
const totalTokens = toNumber(usageData.total_tokens ?? usageData.totalTokens)
// 尝试从多个字段获取 output_tokens
let outputTokens = toNumber(
usageData.output_tokens ?? usageData.completion_tokens ?? usageData.outputTokens
)
// 如果 output_tokens 为 0 但有 total_tokens,从差值计算
if (outputTokens === 0 && totalTokens > 0 && inputTokens >= 0) {
outputTokens = Math.max(0, totalTokens - inputTokens)
}
const cacheReadTokens = toNumber(
usageData.cache_read_input_tokens ??
usageData.cacheReadTokens ??
@@ -894,6 +929,40 @@ class DroidRelayService {
return account.id || account.accountId || account.account_id || null
}
/**
* 根据模型名称推断 API provider
*/
_inferProviderFromModel(model) {
if (!model || typeof model !== 'string') {
return 'baseten'
}
const lowerModel = model.toLowerCase()
// Google Gemini 模型
if (lowerModel.startsWith('gemini-') || lowerModel.includes('gemini')) {
return 'google'
}
// Anthropic Claude 模型
if (lowerModel.startsWith('claude-') || lowerModel.includes('claude')) {
return 'anthropic'
}
// OpenAI GPT 模型
if (lowerModel.startsWith('gpt-') || lowerModel.includes('gpt')) {
return 'azure_openai'
}
// GLM 模型使用 fireworks
if (lowerModel.startsWith('glm-') || lowerModel.includes('glm')) {
return 'fireworks'
}
// 默认使用 baseten
return 'baseten'
}
/**
* 构建请求头
*/
@@ -923,6 +992,12 @@ class DroidRelayService {
headers['x-api-provider'] = 'azure_openai'
}
// Comm 端点根据模型动态设置 provider
if (endpointType === 'comm') {
const model = requestBody?.model
headers['x-api-provider'] = this._inferProviderFromModel(model)
}
// 生成会话 ID如果客户端没有提供
headers['x-session-id'] = clientHeaders['x-session-id'] || this._generateUUID()
@@ -1034,6 +1109,36 @@ class DroidRelayService {
}
}
// Comm 端点:在 messages 数组前注入 system 消息
if (endpointType === 'comm') {
if (this.systemPrompt && Array.isArray(processedBody.messages)) {
const hasSystemMessage = processedBody.messages.some((m) => m && m.role === 'system')
if (hasSystemMessage) {
// 如果已有 system 消息,在第一个 system 消息的 content 前追加
const firstSystemIndex = processedBody.messages.findIndex((m) => m && m.role === 'system')
if (firstSystemIndex !== -1) {
const existingContent = processedBody.messages[firstSystemIndex].content || ''
if (
typeof existingContent === 'string' &&
!existingContent.startsWith(this.systemPrompt)
) {
processedBody.messages[firstSystemIndex] = {
...processedBody.messages[firstSystemIndex],
content: this.systemPrompt + existingContent
}
}
}
} else {
// 如果没有 system 消息,在 messages 数组最前面插入
processedBody.messages = [
{ role: 'system', content: this.systemPrompt },
...processedBody.messages
]
}
}
}
// 处理 temperature 和 top_p 参数
const hasValidTemperature =
processedBody.temperature !== undefined && processedBody.temperature !== null
@@ -1080,11 +1185,17 @@ class DroidRelayService {
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
}
const endpointLabel =
endpointType === 'anthropic'
? ' [anthropic]'
: endpointType === 'comm'
? ' [comm]'
: ' [openai]'
await this._applyRateLimitTracking(
clientRequest?.rateLimitInfo,
usageSummary,
model,
endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]'
endpointLabel
)
logger.success(