Files
new-api/common/init.go
t0ng7u 8e3f9b1faa 🛡️ fix: prevent OOM on large/decompressed requests; skip heavy prompt meta when token count is disabled
Clamp request body size (including post-decompression) to avoid memory exhaustion caused by huge payloads/zip bombs, especially with large-context Claude requests. Add a configurable `MAX_REQUEST_BODY_MB` (default `32`) and document it.

- Enforce max request body size after gzip/br decompression via `http.MaxBytesReader`
- Add a secondary size guard in `common.GetRequestBody` and cache-safe handling
- Return **413 Request Entity Too Large** on oversized bodies in relay entry
- Avoid building large `TokenCountMeta.CombineText` when both token counting and sensitive check are disabled (use lightweight meta for pricing)
- Update READMEs (CN/EN/FR/JA) with `MAX_REQUEST_BODY_MB`
- Fix a handful of vet/formatting issues encountered during the change
- `go test ./...` passes
2025-12-16 17:00:19 +08:00

152 lines
5.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package common
import (
"flag"
"fmt"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/constant"
)
// Command-line flags. They are registered on the default flag set at package
// initialization and hold their parsed values once flag.Parse has run.
var (
	// Where the process listens and where it writes its logs.
	Port   = flag.Int("port", 3000, "the listening port")
	LogDir = flag.String("log-dir", "./logs", "specify the log directory")

	// Informational switches: print something and exit.
	PrintVersion = flag.Bool("version", false, "print version and exit")
	PrintHelp    = flag.Bool("help", false, "print help and exit")
)
// printHelp writes the program banner (including Version), the project
// attribution lines, and the command-line usage summary to standard output,
// one line each.
func printHelp() {
	banner := "NewAPI(Based OneAPI) " + Version + " - The next-generation LLM gateway and AI asset management system supports multiple languages."
	helpLines := []string{
		banner,
		"Original Project: OneAPI by JustSong - https://github.com/songquanpeng/one-api",
		"Maintainer: QuantumNous - https://github.com/QuantumNous/new-api",
		"Usage: newapi [--port <port>] [--log-dir <log directory>] [--version] [--help]",
	}
	for _, line := range helpLines {
		fmt.Println(line)
	}
}
// InitEnv parses the command-line flags and loads process-wide settings from
// environment variables into the package-level configuration variables, then
// calls initConstantEnv.
//
// It exits the process with status 0 after handling --version or --help, and
// terminates through log.Fatal when SESSION_SECRET is left at the placeholder
// value "random_string" or when the log directory cannot be resolved or
// created.
func InitEnv() {
	flag.Parse()

	// A non-empty VERSION environment variable replaces the current Version.
	if envVersion := os.Getenv("VERSION"); envVersion != "" {
		Version = envVersion
	}

	if *PrintVersion {
		fmt.Println(Version)
		os.Exit(0)
	}

	if *PrintHelp {
		printHelp()
		os.Exit(0)
	}

	if ss := os.Getenv("SESSION_SECRET"); ss != "" {
		if ss == "random_string" {
			// Refuse to start with the placeholder secret.
			log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
			log.Println("警告SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
			log.Fatal("Please set SESSION_SECRET to a random string.")
		}
		SessionSecret = ss
	}

	// CRYPTO_SECRET falls back to the session secret when it is not provided.
	if cs := os.Getenv("CRYPTO_SECRET"); cs != "" {
		CryptoSecret = cs
	} else {
		CryptoSecret = SessionSecret
	}

	if sqlitePath := os.Getenv("SQLITE_PATH"); sqlitePath != "" {
		SQLitePath = sqlitePath
	}

	if *LogDir != "" {
		absLogDir, err := filepath.Abs(*LogDir)
		if err != nil {
			log.Fatal(err)
		}
		*LogDir = absLogDir

		if _, statErr := os.Stat(*LogDir); os.IsNotExist(statErr) {
			// MkdirAll (rather than Mkdir) also creates missing parent
			// directories, so a nested --log-dir works, and it returns nil if
			// the directory appeared between the Stat above and this call.
			if err := os.MkdirAll(*LogDir, 0777); err != nil {
				log.Fatal(err)
			}
		}
	}

	// Initialize variables from constants.go that were using environment variables
	DebugEnabled = os.Getenv("DEBUG") == "true"
	MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
	// Every NODE_TYPE value other than exactly "slave" marks this node as master.
	IsMasterNode = os.Getenv("NODE_TYPE") != "slave"

	// Parse requestInterval and set RequestInterval (in seconds). The parse
	// error is deliberately discarded: an unset or non-numeric
	// POLLING_INTERVAL leaves the interval at 0.
	requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
	RequestInterval = time.Duration(requestInterval) * time.Second

	// Initialize variables with GetEnvOrDefault
	SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
	BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
	RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
	RelayMaxIdleConns = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS", 500)
	RelayMaxIdleConnsPerHost = GetEnvOrDefault("RELAY_MAX_IDLE_CONNS_PER_HOST", 100)

	// Initialize string variables with GetEnvOrDefaultString
	GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
	CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")

	// Initialize rate limit variables
	GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
	GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
	GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))

	GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
	GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
	GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))

	CriticalRateLimitEnable = GetEnvOrDefaultBool("CRITICAL_RATE_LIMIT_ENABLE", true)
	CriticalRateLimitNum = GetEnvOrDefault("CRITICAL_RATE_LIMIT", 20)
	CriticalRateLimitDuration = int64(GetEnvOrDefault("CRITICAL_RATE_LIMIT_DURATION", 20*60))

	initConstantEnv()
}
// initConstantEnv populates the settings in the constant package from
// environment variables via the GetEnvOrDefault* helpers, then parses the
// comma-separated TASK_PRICE_PATCH list into constant.TaskPricePatches.
func initConstantEnv() {
	constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
	constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
	constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
	constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)

	// MaxRequestBodyMB is the maximum request body size (after decompression);
	// it guards against memory blow-ups from oversized requests or zip bombs.
	constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 32)

	// ForceStreamOption overrides the request parameters to force usage
	// information to be returned.
	constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)

	constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
	constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
	constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
	constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
	constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
	constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
	constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
	constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)

	// GenerateDefaultToken controls whether an initial token is generated;
	// disabled by default.
	constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)

	// Whether error logging is enabled.
	constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)

	// Maximum number of tasks queried during task polling.
	constant.TaskQueryLimit = GetEnvOrDefault("TASK_QUERY_LIMIT", 1000)

	// TASK_PRICE_PATCH is a comma-separated list. When it is empty,
	// constant.TaskPricePatches is left untouched.
	rawPatches := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
	if rawPatches == "" {
		return
	}

	// Keep only the non-blank entries, trimmed of surrounding whitespace. The
	// slice stays nil when every entry is blank.
	var patches []string
	for _, entry := range strings.Split(rawPatches, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		patches = append(patches, entry)
	}
	constant.TaskPricePatches = patches
}