mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 09:20:04 +00:00
Merge branch 'Wei-Shaw:main' into main
This commit is contained in:
53
.github/workflows/release.yml
vendored
53
.github/workflows/release.yml
vendored
@@ -143,3 +143,56 @@ jobs:
|
||||
repository: ${{ secrets.DOCKERHUB_USERNAME }}/sub2api
|
||||
short-description: "Sub2API - AI API Gateway Platform"
|
||||
readme-filepath: ./deploy/DOCKER.md
|
||||
|
||||
# Send Telegram notification
|
||||
- name: Send Telegram Notification
|
||||
if: ${{ secrets.TELEGRAM_BOT_TOKEN != '' && secrets.TELEGRAM_CHAT_ID != '' }}
|
||||
env:
|
||||
TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }}
|
||||
TELEGRAM_CHAT_ID: ${{ secrets.TELEGRAM_CHAT_ID }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
TAG_NAME=${GITHUB_REF#refs/tags/}
|
||||
VERSION=${TAG_NAME#v}
|
||||
REPO="${{ github.repository }}"
|
||||
DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api"
|
||||
|
||||
# 获取 tag message 内容
|
||||
TAG_MESSAGE='${{ steps.tag_message.outputs.message }}'
|
||||
|
||||
# 限制消息长度(Telegram 消息限制 4096 字符,预留空间给头尾固定内容)
|
||||
if [ ${#TAG_MESSAGE} -gt 3500 ]; then
|
||||
TAG_MESSAGE="${TAG_MESSAGE:0:3500}..."
|
||||
fi
|
||||
|
||||
# 构建消息内容
|
||||
MESSAGE="🚀 *Sub2API 新版本发布!*"$'\n'$'\n'
|
||||
MESSAGE+="📦 版本号: \`${VERSION}\`"$'\n'$'\n'
|
||||
|
||||
# 添加更新内容
|
||||
if [ -n "$TAG_MESSAGE" ]; then
|
||||
MESSAGE+="${TAG_MESSAGE}"$'\n'$'\n'
|
||||
fi
|
||||
|
||||
MESSAGE+="🐳 *Docker 部署:*"$'\n'
|
||||
MESSAGE+="\`\`\`bash"$'\n'
|
||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
|
||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:latest"$'\n'
|
||||
MESSAGE+="\`\`\`"$'\n'$'\n'
|
||||
MESSAGE+="🔗 *相关链接:*"$'\n'
|
||||
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
|
||||
MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n'$'\n'
|
||||
MESSAGE+="#Sub2API #Release #${TAG_NAME//./_}"
|
||||
|
||||
# 发送消息
|
||||
curl -s -X POST "https://api.telegram.org/bot${TELEGRAM_BOT_TOKEN}/sendMessage" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$(jq -n \
|
||||
--arg chat_id "${TELEGRAM_CHAT_ID}" \
|
||||
--arg text "${MESSAGE}" \
|
||||
'{
|
||||
chat_id: $chat_id,
|
||||
text: $text,
|
||||
parse_mode: "Markdown",
|
||||
disable_web_page_preview: true
|
||||
}')"
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -127,66 +128,158 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
platform = apiKey.Group.Platform
|
||||
}
|
||||
|
||||
// 选择支持该模型的账号
|
||||
var account *service.Account
|
||||
if platform == service.PlatformGemini {
|
||||
account, err = h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
} else {
|
||||
account, err = h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 转发请求
|
||||
var result *service.ForwardResult
|
||||
if platform == service.PlatformGemini {
|
||||
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
}
|
||||
if err != nil {
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if req.Stream {
|
||||
sendMockWarmupStream(c, req.Model)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, req.Model)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 转发请求
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles listing available models
|
||||
@@ -314,6 +407,28 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -158,44 +159,69 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// 5) forward (writes response to client)
|
||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
if err != nil {
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
for {
|
||||
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
|
||||
if err != nil {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) forward (writes response to client)
|
||||
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// ForwardNative already wrote the response
|
||||
log.Printf("Gemini native forward failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func parseGeminiModelAction(rest string) (model string, action string, err error) {
|
||||
@@ -217,6 +243,28 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
func mapGeminiUpstreamError(statusCode int) (int, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
type pathParseError struct{ msg string }
|
||||
|
||||
func (e *pathParseError) Error() string { return e.msg }
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -127,49 +128,74 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
const maxAccountSwitches = 3
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
defer accountReleaseFunc()
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if err != nil {
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
|
||||
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
if err != nil {
|
||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
}()
|
||||
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// Forward request
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Async record usage
|
||||
go func(result *service.OpenAIForwardResult, usedAccount *service.Account) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||
@@ -178,6 +204,28 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||||
case 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||||
default:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
|
||||
@@ -199,16 +199,20 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauth.ClientID)
|
||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": oauth.ClientID,
|
||||
}
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -34,7 +33,6 @@ type requestCapture struct {
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
formValues url.Values
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
@@ -282,24 +280,53 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_form",
|
||||
name: "sends_json_format",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "new_refresh_token",
|
||||
Scope: "user:profile user:inference",
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
|
||||
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
|
||||
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
|
||||
// 验证使用 JSON 格式(不是 form 格式)
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
|
||||
"expected JSON content-type, got: %s", captured.contentType)
|
||||
// 验证 JSON body 内容
|
||||
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
|
||||
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns_new_refresh_token",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rotated_rt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -311,8 +338,9 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
|
||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
captured.formValues, _ = url.ParseQuery(string(captured.body))
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}))
|
||||
defer s.srv.Close()
|
||||
@@ -331,6 +359,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
|
||||
@@ -925,8 +925,38 @@ func (r *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64,
|
||||
|
||||
func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs := r.userLogs[filters.UserID]
|
||||
total := int64(len(logs))
|
||||
out := paginateLogs(logs, params)
|
||||
|
||||
// Apply filters
|
||||
var filtered []service.UsageLog
|
||||
for _, log := range logs {
|
||||
// Apply ApiKeyID filter
|
||||
if filters.ApiKeyID > 0 && log.ApiKeyID != filters.ApiKeyID {
|
||||
continue
|
||||
}
|
||||
// Apply Model filter
|
||||
if filters.Model != "" && log.Model != filters.Model {
|
||||
continue
|
||||
}
|
||||
// Apply Stream filter
|
||||
if filters.Stream != nil && log.Stream != *filters.Stream {
|
||||
continue
|
||||
}
|
||||
// Apply BillingType filter
|
||||
if filters.BillingType != nil && log.BillingType != *filters.BillingType {
|
||||
continue
|
||||
}
|
||||
// Apply time range filters
|
||||
if filters.StartTime != nil && log.CreatedAt.Before(*filters.StartTime) {
|
||||
continue
|
||||
}
|
||||
if filters.EndTime != nil && log.CreatedAt.After(*filters.EndTime) {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, log)
|
||||
}
|
||||
|
||||
total := int64(len(filtered))
|
||||
out := paginateLogs(filtered, params)
|
||||
return out, paginationResult(total, params), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
@@ -82,12 +85,25 @@ func (a *Account) GetCredential(key string) string {
|
||||
if a.Credentials == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Credentials[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
v, ok := a.Credentials[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 支持多种类型(兼容历史数据中 expires_at 等字段可能是数字或字符串)
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val
|
||||
case float64:
|
||||
// JSON 解析后数字默认为 float64
|
||||
return strconv.FormatInt(int64(val), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(val, 10)
|
||||
case int:
|
||||
return strconv.Itoa(val)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
|
||||
@@ -208,20 +208,23 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
|
||||
account.Status = *req.Status
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
// 更新分组绑定
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
// 验证分组是否存在
|
||||
for _, groupID := range *req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 执行更新
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if req.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -652,11 +652,20 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Status = input.Status
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if input.GroupIDs != nil {
|
||||
for _, groupID := range *input.GroupIDs {
|
||||
if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新分组绑定
|
||||
// 绑定分组
|
||||
if input.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -81,6 +81,15 @@ type ForwardResult struct {
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
}
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
|
||||
}
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
accountRepo AccountRepository
|
||||
@@ -274,19 +283,26 @@ func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sess
|
||||
|
||||
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||
}
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -307,6 +323,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 检查模型支持
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
@@ -394,6 +413,16 @@ func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode i
|
||||
return !account.ShouldHandleErrorCode(statusCode)
|
||||
}
|
||||
|
||||
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
|
||||
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -478,9 +507,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 处理重试耗尽的情况
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// 处理可切换账号的错误
|
||||
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -692,10 +731,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// handleRetryExhaustedError 处理重试耗尽后的错误
|
||||
// OAuth 403:标记账号异常
|
||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
statusCode := resp.StatusCode
|
||||
|
||||
@@ -707,6 +743,18 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
// API Key 未配置错误码:不标记账号状态
|
||||
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
// handleRetryExhaustedError 处理重试耗尽后的错误
|
||||
// OAuth 403:标记账号异常
|
||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
|
||||
// 返回统一的重试耗尽错误响应
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
@@ -717,7 +765,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
|
||||
},
|
||||
})
|
||||
|
||||
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", statusCode)
|
||||
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// streamingResult 流式响应结果
|
||||
|
||||
@@ -62,14 +62,20 @@ func (s *GeminiMessagesCompatService) GetTokenProvider() *GeminiTokenProvider {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.Platform == PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,6 +94,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -425,6 +434,9 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody)
|
||||
}
|
||||
|
||||
@@ -724,6 +736,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
@@ -795,6 +811,15 @@ func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Ac
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) shouldFailoverGeminiUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
func sleepGeminiBackoff(attempt int) {
|
||||
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||||
if delay > geminiRetryMaxDelay {
|
||||
|
||||
@@ -129,15 +129,22 @@ func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64
|
||||
|
||||
// SelectAccountForModel selects an account supporting the requested model
|
||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
|
||||
}
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,6 +165,9 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
@@ -221,6 +231,20 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
switch statusCode {
|
||||
case 401, 403, 429, 529:
|
||||
return true
|
||||
default:
|
||||
return statusCode >= 500
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -288,6 +312,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
|
||||
@@ -280,10 +280,12 @@
|
||||
import { ref, watch, nextTick } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Modal from '@/components/common/Modal.vue'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import type { Account, ClaudeModel } from '@/types'
|
||||
|
||||
const { t } = useI18n()
|
||||
const { copyToClipboard } = useClipboard()
|
||||
|
||||
interface OutputLine {
|
||||
text: string
|
||||
@@ -501,6 +503,6 @@ const handleEvent = (event: {
|
||||
|
||||
const copyOutput = () => {
|
||||
const text = outputLines.value.map((l) => l.text).join('\n')
|
||||
navigator.clipboard.writeText(text)
|
||||
copyToClipboard(text, t('admin.accounts.outputCopied'))
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -119,7 +119,7 @@
|
||||
import { ref, computed, h, watch, type Component } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import Modal from '@/components/common/Modal.vue'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import type { GroupPlatform } from '@/types'
|
||||
|
||||
interface Props {
|
||||
@@ -150,7 +150,7 @@ const props = defineProps<Props>()
|
||||
const emit = defineEmits<Emits>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
const { copyToClipboard: clipboardCopy } = useClipboard()
|
||||
|
||||
const copiedIndex = ref<number | null>(null)
|
||||
const activeTab = ref<string>('unix')
|
||||
@@ -340,14 +340,12 @@ ${key('requires_openai_auth')} ${operator('=')} ${keyword('true')}`
|
||||
}
|
||||
|
||||
const copyContent = async (content: string, index: number) => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(content)
|
||||
const success = await clipboardCopy(content, t('keys.copied'))
|
||||
if (success) {
|
||||
copiedIndex.value = index
|
||||
setTimeout(() => {
|
||||
copiedIndex.value = null
|
||||
}, 2000)
|
||||
} catch (error) {
|
||||
appStore.showError(t('common.copyFailed'))
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -1,40 +1,65 @@
|
||||
import { ref } from 'vue'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
|
||||
/**
|
||||
* 检测是否支持 Clipboard API(需要安全上下文:HTTPS/localhost)
|
||||
*/
|
||||
function isClipboardSupported(): boolean {
|
||||
return !!(navigator.clipboard && window.isSecureContext)
|
||||
}
|
||||
|
||||
/**
|
||||
* 降级方案:使用 textarea + execCommand
|
||||
* 使用 textarea 而非 input,以正确处理多行文本
|
||||
*/
|
||||
function fallbackCopy(text: string): boolean {
|
||||
const textarea = document.createElement('textarea')
|
||||
textarea.value = text
|
||||
textarea.style.cssText = 'position:fixed;left:-9999px;top:-9999px'
|
||||
document.body.appendChild(textarea)
|
||||
textarea.select()
|
||||
try {
|
||||
return document.execCommand('copy')
|
||||
} finally {
|
||||
document.body.removeChild(textarea)
|
||||
}
|
||||
}
|
||||
|
||||
export function useClipboard() {
|
||||
const appStore = useAppStore()
|
||||
const copied = ref(false)
|
||||
|
||||
const copyToClipboard = async (text: string, successMessage = 'Copied to clipboard') => {
|
||||
const copyToClipboard = async (
|
||||
text: string,
|
||||
successMessage = 'Copied to clipboard'
|
||||
): Promise<boolean> => {
|
||||
if (!text) return false
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(text)
|
||||
copied.value = true
|
||||
appStore.showSuccess(successMessage)
|
||||
setTimeout(() => {
|
||||
copied.value = false
|
||||
}, 2000)
|
||||
return true
|
||||
} catch {
|
||||
// Fallback for older browsers
|
||||
const input = document.createElement('input')
|
||||
input.value = text
|
||||
document.body.appendChild(input)
|
||||
input.select()
|
||||
document.execCommand('copy')
|
||||
document.body.removeChild(input)
|
||||
copied.value = true
|
||||
appStore.showSuccess(successMessage)
|
||||
setTimeout(() => {
|
||||
copied.value = false
|
||||
}, 2000)
|
||||
return true
|
||||
let success = false
|
||||
|
||||
if (isClipboardSupported()) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text)
|
||||
success = true
|
||||
} catch {
|
||||
success = fallbackCopy(text)
|
||||
}
|
||||
} else {
|
||||
success = fallbackCopy(text)
|
||||
}
|
||||
|
||||
if (success) {
|
||||
copied.value = true
|
||||
appStore.showSuccess(successMessage)
|
||||
setTimeout(() => {
|
||||
copied.value = false
|
||||
}, 2000)
|
||||
} else {
|
||||
appStore.showError('Copy failed')
|
||||
}
|
||||
|
||||
return success
|
||||
}
|
||||
|
||||
return {
|
||||
copied,
|
||||
copyToClipboard
|
||||
}
|
||||
return { copied, copyToClipboard }
|
||||
}
|
||||
|
||||
@@ -418,6 +418,7 @@
|
||||
import { ref, reactive, computed, onMounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { adminAPI } from '@/api/admin'
|
||||
import { formatDateTime } from '@/utils/format'
|
||||
import type { RedeemCode, RedeemCodeType, Group } from '@/types'
|
||||
@@ -431,6 +432,7 @@ import Select from '@/components/common/Select.vue'
|
||||
|
||||
const { t } = useI18n()
|
||||
const appStore = useAppStore()
|
||||
const { copyToClipboard: clipboardCopy } = useClipboard()
|
||||
|
||||
const showGenerateDialog = ref(false)
|
||||
const showResultDialog = ref(false)
|
||||
@@ -618,15 +620,12 @@ const handleGenerateCodes = async () => {
|
||||
}
|
||||
|
||||
const copyToClipboard = async (text: string) => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text)
|
||||
const success = await clipboardCopy(text, t('admin.redeem.copied'))
|
||||
if (success) {
|
||||
copiedCode.value = text
|
||||
setTimeout(() => {
|
||||
copiedCode.value = null
|
||||
}, 2000)
|
||||
} catch (error) {
|
||||
appStore.showError(t('admin.redeem.failedToCopy'))
|
||||
console.error('Error copying to clipboard:', error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1173,6 +1173,7 @@
|
||||
import { ref, reactive, computed, onMounted } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
import { formatDateTime } from '@/utils/format'
|
||||
|
||||
const { t } = useI18n()
|
||||
@@ -1191,6 +1192,7 @@ import Select from '@/components/common/Select.vue'
|
||||
import GroupBadge from '@/components/common/GroupBadge.vue'
|
||||
|
||||
const appStore = useAppStore()
|
||||
const { copyToClipboard: clipboardCopy } = useClipboard()
|
||||
|
||||
const columns = computed<Column[]>(() => [
|
||||
{ key: 'email', label: t('admin.users.columns.user'), sortable: true },
|
||||
@@ -1312,27 +1314,23 @@ const generateEditPassword = () => {
|
||||
|
||||
const copyPassword = async () => {
|
||||
if (!createForm.password) return
|
||||
try {
|
||||
await navigator.clipboard.writeText(createForm.password)
|
||||
const success = await clipboardCopy(createForm.password, t('admin.users.passwordCopied'))
|
||||
if (success) {
|
||||
passwordCopied.value = true
|
||||
setTimeout(() => {
|
||||
passwordCopied.value = false
|
||||
}, 2000)
|
||||
} catch (error) {
|
||||
appStore.showError(t('common.copyFailed'))
|
||||
}
|
||||
}
|
||||
|
||||
const copyEditPassword = async () => {
|
||||
if (!editForm.password) return
|
||||
try {
|
||||
await navigator.clipboard.writeText(editForm.password)
|
||||
const success = await clipboardCopy(editForm.password, t('admin.users.passwordCopied'))
|
||||
if (success) {
|
||||
editPasswordCopied.value = true
|
||||
setTimeout(() => {
|
||||
editPasswordCopied.value = false
|
||||
}, 2000)
|
||||
} catch (error) {
|
||||
appStore.showError(t('common.copyFailed'))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -493,6 +493,7 @@
|
||||
import { ref, computed, onMounted, onUnmounted, type ComponentPublicInstance } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { useAppStore } from '@/stores/app'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
|
||||
const { t } = useI18n()
|
||||
import { keysAPI, authAPI, usageAPI, userGroupsAPI } from '@/api'
|
||||
@@ -520,6 +521,7 @@ interface GroupOption {
|
||||
}
|
||||
|
||||
const appStore = useAppStore()
|
||||
const { copyToClipboard: clipboardCopy } = useClipboard()
|
||||
|
||||
const columns = computed<Column[]>(() => [
|
||||
{ key: 'name', label: t('common.name'), sortable: true },
|
||||
@@ -616,14 +618,12 @@ const maskKey = (key: string): string => {
|
||||
}
|
||||
|
||||
const copyToClipboard = async (text: string, keyId: number) => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text)
|
||||
const success = await clipboardCopy(text, t('keys.copied'))
|
||||
if (success) {
|
||||
copiedKeyId.value = keyId
|
||||
setTimeout(() => {
|
||||
copiedKeyId.value = null
|
||||
}, 2000)
|
||||
} catch (error) {
|
||||
appStore.showError(t('common.copyFailed'))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user