refactor(ratio): replace maps with RWMap for improved concurrency handling

This commit is contained in:
CaIon
2026-02-08 00:48:21 +08:00
parent 7a146a11f5
commit 44c5fac5ea
5 changed files with 115 additions and 360 deletions

View File

@@ -1,10 +1,7 @@
package ratio_setting
import (
"encoding/json"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
)
var defaultCacheRatio = map[string]float64{
@@ -98,70 +95,37 @@ var defaultCreateCacheRatio = map[string]float64{
//var defaultCreateCacheRatio = map[string]float64{}
var cacheRatioMap map[string]float64
var cacheRatioMapMutex sync.RWMutex
var cacheRatioMap = types.NewRWMap[string, float64]()
var createCacheRatioMap = types.NewRWMap[string, float64]()
var createCacheRatioMap map[string]float64
var createCacheRatioMapMutex sync.RWMutex
// GetCacheRatioMap returns the cache ratio map
// GetCacheRatioMap returns a copy of the cache ratio map
func GetCacheRatioMap() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
return cacheRatioMap
return cacheRatioMap.ReadAll()
}
// CacheRatio2JSONString converts the cache ratio map to a JSON string
func CacheRatio2JSONString() string {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil {
common.SysLog("error marshalling cache ratio: " + err.Error())
}
return string(jsonBytes)
return cacheRatioMap.MarshalJSONString()
}
// CreateCacheRatio2JSONString converts the create cache ratio map to a JSON string
func CreateCacheRatio2JSONString() string {
createCacheRatioMapMutex.RLock()
defer createCacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(createCacheRatioMap)
if err != nil {
common.SysLog("error marshalling create cache ratio: " + err.Error())
}
return string(jsonBytes)
return createCacheRatioMap.MarshalJSONString()
}
// UpdateCacheRatioByJSONString updates the cache ratio map from a JSON string
func UpdateCacheRatioByJSONString(jsonStr string) error {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
cacheRatioMap = make(map[string]float64)
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(cacheRatioMap, jsonStr, InvalidateExposedDataCache)
}
// UpdateCreateCacheRatioByJSONString updates the create cache ratio map from a JSON string
func UpdateCreateCacheRatioByJSONString(jsonStr string) error {
createCacheRatioMapMutex.Lock()
defer createCacheRatioMapMutex.Unlock()
createCacheRatioMap = make(map[string]float64)
err := json.Unmarshal([]byte(jsonStr), &createCacheRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
return types.LoadFromJsonStringWithCallback(createCacheRatioMap, jsonStr, InvalidateExposedDataCache)
}
// GetCacheRatio returns the cache ratio for a model
func GetCacheRatio(name string) (float64, bool) {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
ratio, ok := cacheRatioMap[name]
ratio, ok := cacheRatioMap.Get(name)
if !ok {
return 1, false // Default to 1 if not found
}
@@ -169,9 +133,7 @@ func GetCacheRatio(name string) (float64, bool) {
}
func GetCreateCacheRatio(name string) (float64, bool) {
createCacheRatioMapMutex.RLock()
defer createCacheRatioMapMutex.RUnlock()
ratio, ok := createCacheRatioMap[name]
ratio, ok := createCacheRatioMap.Get(name)
if !ok {
return 1.25, false // Default to 1.25 if not found
}
@@ -179,21 +141,9 @@ func GetCreateCacheRatio(name string) (float64, bool) {
}
func GetCacheRatioCopy() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(cacheRatioMap))
for k, v := range cacheRatioMap {
copyMap[k] = v
}
return copyMap
return cacheRatioMap.ReadAll()
}
func GetCreateCacheRatioCopy() map[string]float64 {
createCacheRatioMapMutex.RLock()
defer createCacheRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(createCacheRatioMap))
for k, v := range createCacheRatioMap {
copyMap[k] = v
}
return copyMap
return createCacheRatioMap.ReadAll()
}