refactor(ratio): replace maps with RWMap for improved concurrency handling

This commit is contained in:
CaIon
2026-02-08 00:48:21 +08:00
parent 7a146a11f5
commit 44c5fac5ea
5 changed files with 115 additions and 360 deletions

View File

@@ -1,12 +1,11 @@
package ratio_setting
import (
"encoding/json"
"strings"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/types"
)
// from songquanpeng/one-api
@@ -319,19 +318,9 @@ var defaultAudioCompletionRatio = map[string]float64{
"tts-1-hd-1106": 0,
}
var (
modelPriceMap map[string]float64 = nil
modelPriceMapMutex = sync.RWMutex{}
)
var (
modelRatioMap map[string]float64 = nil
modelRatioMapMutex = sync.RWMutex{}
)
var (
CompletionRatio map[string]float64 = nil
CompletionRatioMutex = sync.RWMutex{}
)
var modelPriceMap = types.NewRWMap[string, float64]()
var modelRatioMap = types.NewRWMap[string, float64]()
var completionRatioMap = types.NewRWMap[string, float64]()
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
@@ -342,84 +331,34 @@ var defaultCompletionRatio = map[string]float64{
// InitRatioSettings initializes all model related settings maps
func InitRatioSettings() {
// Initialize modelPriceMap
modelPriceMapMutex.Lock()
modelPriceMap = defaultModelPrice
modelPriceMapMutex.Unlock()
// Initialize modelRatioMap
modelRatioMapMutex.Lock()
modelRatioMap = defaultModelRatio
modelRatioMapMutex.Unlock()
// Initialize CompletionRatio
CompletionRatioMutex.Lock()
CompletionRatio = defaultCompletionRatio
CompletionRatioMutex.Unlock()
// Initialize cacheRatioMap
cacheRatioMapMutex.Lock()
cacheRatioMap = defaultCacheRatio
cacheRatioMapMutex.Unlock()
// Initialize createCacheRatioMap (5m cache creation ratio)
createCacheRatioMapMutex.Lock()
createCacheRatioMap = defaultCreateCacheRatio
createCacheRatioMapMutex.Unlock()
// initialize imageRatioMap
imageRatioMapMutex.Lock()
imageRatioMap = defaultImageRatio
imageRatioMapMutex.Unlock()
// initialize audioRatioMap
audioRatioMapMutex.Lock()
audioRatioMap = defaultAudioRatio
audioRatioMapMutex.Unlock()
// initialize audioCompletionRatioMap
audioCompletionRatioMapMutex.Lock()
audioCompletionRatioMap = defaultAudioCompletionRatio
audioCompletionRatioMapMutex.Unlock()
modelPriceMap.AddAll(defaultModelPrice)
modelRatioMap.AddAll(defaultModelRatio)
completionRatioMap.AddAll(defaultCompletionRatio)
cacheRatioMap.AddAll(defaultCacheRatio)
createCacheRatioMap.AddAll(defaultCreateCacheRatio)
imageRatioMap.AddAll(defaultImageRatio)
audioRatioMap.AddAll(defaultAudioRatio)
audioCompletionRatioMap.AddAll(defaultAudioCompletionRatio)
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
return modelPriceMap
return modelPriceMap.ReadAll()
}
func ModelPrice2JSONString() string {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
jsonBytes, err := common.Marshal(modelPriceMap)
if err != nil {
common.SysError("error marshalling model price: " + err.Error())
}
return string(jsonBytes)
return modelPriceMap.MarshalJSONString()
}
func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(modelPriceMap, jsonStr, InvalidateExposedDataCache)
}
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if strings.HasSuffix(name, CompactModelSuffix) {
price, ok := modelPriceMap[CompactWildcardModelKey]
price, ok := modelPriceMap.Get(CompactWildcardModelKey)
if !ok {
if printErr {
common.SysError("model price not found: " + name)
@@ -429,7 +368,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true
}
price, ok := modelPriceMap[name]
price, ok := modelPriceMap.Get(name)
if !ok {
if printErr {
common.SysError("model price not found: " + name)
@@ -440,14 +379,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
}
func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
err := common.Unmarshal([]byte(jsonStr), &modelRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(modelRatioMap, jsonStr, InvalidateExposedDataCache)
}
// 处理带有思考预算的模型名称,方便统一定价
@@ -459,15 +391,12 @@ func handleThinkingBudgetModel(name, prefix, wildcard string) string {
}
func GetModelRatio(name string) (float64, bool, string) {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
ratio, ok := modelRatioMap[name]
ratio, ok := modelRatioMap.Get(name)
if !ok {
if strings.HasSuffix(name, CompactModelSuffix) {
if wildcardRatio, ok := modelRatioMap[CompactWildcardModelKey]; ok {
if wildcardRatio, ok := modelRatioMap.Get(CompactWildcardModelKey); ok {
return wildcardRatio, true, name
}
//return 0, true, name
@@ -493,54 +422,19 @@ func GetDefaultModelPriceMap() map[string]float64 {
return defaultModelPrice
}
func GetDefaultImageRatioMap() map[string]float64 {
return defaultImageRatio
}
func GetDefaultAudioRatioMap() map[string]float64 {
return defaultAudioRatio
}
func GetDefaultAudioCompletionRatioMap() map[string]float64 {
return defaultAudioCompletionRatio
}
func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
return CompletionRatio
}
func CompletionRatio2JSONString() string {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
common.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
return completionRatioMap.MarshalJSONString()
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
err := common.Unmarshal([]byte(jsonStr), &CompletionRatio)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(completionRatioMap, jsonStr, InvalidateExposedDataCache)
}
func GetCompletionRatio(name string) float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
name = FormatMatchingModelName(name)
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
if ratio, ok := completionRatioMap.Get(name); ok {
return ratio
}
}
@@ -548,7 +442,7 @@ func GetCompletionRatio(name string) float64 {
if contain {
return hardCodedRatio
}
if ratio, ok := CompletionRatio[name]; ok {
if ratio, ok := completionRatioMap.Get(name); ok {
return ratio
}
return hardCodedRatio
@@ -676,88 +570,54 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
}
func GetAudioRatio(name string) float64 {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if ratio, ok := audioRatioMap[name]; ok {
if ratio, ok := audioRatioMap.Get(name); ok {
return ratio
}
return 1
}
func GetAudioCompletionRatio(name string) float64 {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
if ratio, ok := audioCompletionRatioMap[name]; ok {
if ratio, ok := audioCompletionRatioMap.Get(name); ok {
return ratio
}
return 1
}
func ContainsAudioRatio(name string) bool {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
_, ok := audioRatioMap[name]
_, ok := audioRatioMap.Get(name)
return ok
}
func ContainsAudioCompletionRatio(name string) bool {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
_, ok := audioCompletionRatioMap[name]
_, ok := audioCompletionRatioMap.Get(name)
return ok
}
func ModelRatio2JSONString() string {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
return modelRatioMap.MarshalJSONString()
}
var defaultImageRatio = map[string]float64{
"gpt-image-1": 2,
}
var imageRatioMap map[string]float64
var imageRatioMapMutex sync.RWMutex
var (
audioRatioMap map[string]float64 = nil
audioRatioMapMutex = sync.RWMutex{}
)
var (
audioCompletionRatioMap map[string]float64 = nil
audioCompletionRatioMapMutex = sync.RWMutex{}
)
var imageRatioMap = types.NewRWMap[string, float64]()
var audioRatioMap = types.NewRWMap[string, float64]()
var audioCompletionRatioMap = types.NewRWMap[string, float64]()
func ImageRatio2JSONString() string {
imageRatioMapMutex.RLock()
defer imageRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(imageRatioMap)
if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error())
}
return string(jsonBytes)
return imageRatioMap.MarshalJSONString()
}
func UpdateImageRatioByJSONString(jsonStr string) error {
imageRatioMapMutex.Lock()
defer imageRatioMapMutex.Unlock()
imageRatioMap = make(map[string]float64)
return common.Unmarshal([]byte(jsonStr), &imageRatioMap)
return types.LoadFromJsonString(imageRatioMap, jsonStr)
}
func GetImageRatio(name string) (float64, bool) {
imageRatioMapMutex.RLock()
defer imageRatioMapMutex.RUnlock()
ratio, ok := imageRatioMap[name]
ratio, ok := imageRatioMap.Get(name)
if !ok {
return 1, false // Default to 1 if not found
}
@@ -765,78 +625,31 @@ func GetImageRatio(name string) (float64, bool) {
}
func AudioRatio2JSONString() string {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(audioRatioMap)
if err != nil {
common.SysError("error marshalling audio ratio: " + err.Error())
}
return string(jsonBytes)
return audioRatioMap.MarshalJSONString()
}
func UpdateAudioRatioByJSONString(jsonStr string) error {
tmp := make(map[string]float64)
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
return err
}
audioRatioMapMutex.Lock()
audioRatioMap = tmp
audioRatioMapMutex.Unlock()
InvalidateExposedDataCache()
return nil
return types.LoadFromJsonStringWithCallback(audioRatioMap, jsonStr, InvalidateExposedDataCache)
}
func AudioCompletionRatio2JSONString() string {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
jsonBytes, err := common.Marshal(audioCompletionRatioMap)
if err != nil {
common.SysError("error marshalling audio completion ratio: " + err.Error())
}
return string(jsonBytes)
return audioCompletionRatioMap.MarshalJSONString()
}
func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
tmp := make(map[string]float64)
if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
return err
}
audioCompletionRatioMapMutex.Lock()
audioCompletionRatioMap = tmp
audioCompletionRatioMapMutex.Unlock()
InvalidateExposedDataCache()
return nil
return types.LoadFromJsonStringWithCallback(audioCompletionRatioMap, jsonStr, InvalidateExposedDataCache)
}
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelRatioMap))
for k, v := range modelRatioMap {
copyMap[k] = v
}
return copyMap
return modelRatioMap.ReadAll()
}
func GetModelPriceCopy() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelPriceMap))
for k, v := range modelPriceMap {
copyMap[k] = v
}
return copyMap
return modelPriceMap.ReadAll()
}
func GetCompletionRatioCopy() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
copyMap := make(map[string]float64, len(CompletionRatio))
for k, v := range CompletionRatio {
copyMap[k] = v
}
return copyMap
return completionRatioMap.ReadAll()
}
// 转换模型名,减少渠道必须配置各种带参数模型