mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 09:35:21 +00:00
1. Async task model redirection (aligned with sync tasks):
- Integrate ModelMappedHelper in RelayTaskSubmit after model name
determination, populating OriginModelName / UpstreamModelName on RelayInfo.
- All task adaptors now send UpstreamModelName to upstream providers:
- Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
- Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
- Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
and unconditionally uses info.UpstreamModelName.
- Sora: BuildRequestBody parses JSON and multipart bodies to replace
the "model" field with UpstreamModelName.
- Frontend log visibility: LogTaskConsumption and taskBillingOther now
emit is_model_mapped / upstream_model_name in the "other" JSON field.
- Billing safety: RecalculateTaskQuotaByTokens reads model name from
BillingContext.OriginModelName (via taskModelName) instead of
task.Data["model"], preventing billing leaks from upstream model names.
2. Per-call billing (TaskPricePatches lifecycle):
- Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
bool field, populated from TaskPricePatches at submission time.
- settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
skipping both adaptor adjustments and token-based recalculation.
- Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
consistently in controller/relay.go for billing context and logging.
3. Multipart retry boundary mismatch fix:
- Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
new boundary and overwrites c.Request.Header["Content-Type"], subsequent
calls to ParseMultipartFormReusable on retry would parse the cached
original body with the wrong boundary, causing "NextPart: EOF".
- Fix: ParseMultipartFormReusable now caches the original Content-Type in
gin context key "_original_multipart_ct" on first call and reuses it for
all subsequent parses, making multipart parsing retry-safe globally.
- Sora adaptor reverted to the standard pattern (direct header set/get),
which is now safe thanks to the root fix.
4. Tests:
- task_billing_test.go: update makeTask to use OriginModelName; add
PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
add non-per-call adaptor adjustment test with refund verification.
360 lines
8.7 KiB
Go
360 lines
8.7 KiB
Go
package common
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
const KeyRequestBody = "key_request_body"
|
|
const KeyBodyStorage = "key_body_storage"
|
|
|
|
var ErrRequestBodyTooLarge = errors.New("request body too large")
|
|
|
|
func IsRequestBodyTooLargeError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
if errors.Is(err, ErrRequestBodyTooLarge) {
|
|
return true
|
|
}
|
|
var mbe *http.MaxBytesError
|
|
return errors.As(err, &mbe)
|
|
}
|
|
|
|
func GetRequestBody(c *gin.Context) (io.Seeker, error) {
|
|
// 首先检查是否有 BodyStorage 缓存
|
|
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
|
|
if bs, ok := storage.(BodyStorage); ok {
|
|
if _, err := bs.Seek(0, io.SeekStart); err != nil {
|
|
return nil, fmt.Errorf("failed to seek body storage: %w", err)
|
|
}
|
|
return bs, nil
|
|
}
|
|
}
|
|
|
|
// 检查旧的缓存方式
|
|
cached, exists := c.Get(KeyRequestBody)
|
|
if exists && cached != nil {
|
|
if b, ok := cached.([]byte); ok {
|
|
bs, err := CreateBodyStorage(b)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.Set(KeyBodyStorage, bs)
|
|
return bs, nil
|
|
}
|
|
}
|
|
|
|
maxMB := constant.MaxRequestBodyMB
|
|
if maxMB <= 0 {
|
|
maxMB = 128 // 默认 128MB
|
|
}
|
|
maxBytes := int64(maxMB) << 20
|
|
|
|
contentLength := c.Request.ContentLength
|
|
|
|
// 使用新的存储系统
|
|
storage, err := CreateBodyStorageFromReader(c.Request.Body, contentLength, maxBytes)
|
|
_ = c.Request.Body.Close()
|
|
|
|
if err != nil {
|
|
if IsRequestBodyTooLargeError(err) {
|
|
return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB))
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// 缓存存储对象
|
|
c.Set(KeyBodyStorage, storage)
|
|
|
|
return storage, nil
|
|
}
|
|
|
|
// GetBodyStorage 获取请求体存储对象(用于需要多次读取的场景)
|
|
func GetBodyStorage(c *gin.Context) (BodyStorage, error) {
|
|
seeker, err := GetRequestBody(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
bs, ok := seeker.(BodyStorage)
|
|
if !ok {
|
|
return nil, errors.New("unexpected body storage type")
|
|
}
|
|
return bs, nil
|
|
}
|
|
|
|
// CleanupBodyStorage 清理请求体存储(应在请求结束时调用)
|
|
func CleanupBodyStorage(c *gin.Context) {
|
|
if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil {
|
|
if bs, ok := storage.(BodyStorage); ok {
|
|
bs.Close()
|
|
}
|
|
c.Set(KeyBodyStorage, nil)
|
|
}
|
|
}
|
|
|
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|
storage, err := GetBodyStorage(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
requestBody, err := storage.Bytes()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
if strings.HasPrefix(contentType, "application/json") {
|
|
err = Unmarshal(requestBody, v)
|
|
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
|
err = parseFormData(requestBody, v)
|
|
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
|
|
err = parseMultipartFormData(c, requestBody, v)
|
|
} else {
|
|
// skip for now
|
|
// TODO: someday non json request have variant model, we will need to implementation this
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Reset request body
|
|
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
|
return seekErr
|
|
}
|
|
c.Request.Body = io.NopCloser(storage)
|
|
return nil
|
|
}
|
|
|
|
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
|
c.Set(string(key), value)
|
|
}
|
|
|
|
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
|
return c.Get(string(key))
|
|
}
|
|
|
|
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
|
return c.GetString(string(key))
|
|
}
|
|
|
|
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
|
return c.GetInt(string(key))
|
|
}
|
|
|
|
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
|
return c.GetBool(string(key))
|
|
}
|
|
|
|
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
|
return c.GetStringSlice(string(key))
|
|
}
|
|
|
|
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
|
return c.GetStringMap(string(key))
|
|
}
|
|
|
|
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
|
return c.GetTime(string(key))
|
|
}
|
|
|
|
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
|
if value, ok := c.Get(string(key)); ok {
|
|
if v, ok := value.(T); ok {
|
|
return v, true
|
|
}
|
|
}
|
|
var t T
|
|
return t, false
|
|
}
|
|
|
|
func ApiError(c *gin.Context, err error) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
}
|
|
|
|
func ApiErrorMsg(c *gin.Context, msg string) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": msg,
|
|
})
|
|
}
|
|
|
|
func ApiSuccess(c *gin.Context, data any) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
// ApiErrorI18n returns a translated error message based on the user's language preference
|
|
// key is the i18n message key, args is optional template data
|
|
func ApiErrorI18n(c *gin.Context, key string, args ...map[string]any) {
|
|
msg := TranslateMessage(c, key, args...)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": msg,
|
|
})
|
|
}
|
|
|
|
// ApiSuccessI18n returns a translated success message based on the user's language preference
|
|
func ApiSuccessI18n(c *gin.Context, key string, data any, args ...map[string]any) {
|
|
msg := TranslateMessage(c, key, args...)
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": msg,
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
// TranslateMessage is a helper function that calls i18n.T
|
|
// This function is defined here to avoid circular imports
|
|
// The actual implementation will be set during init
|
|
var TranslateMessage func(c *gin.Context, key string, args ...map[string]any) string
|
|
|
|
func init() {
|
|
// Default implementation that returns the key as-is
|
|
// This will be replaced by i18n.T during i18n initialization
|
|
TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string {
|
|
return key
|
|
}
|
|
}
|
|
|
|
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
|
storage, err := GetBodyStorage(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
requestBody, err := storage.Bytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Use the original Content-Type saved on first call to avoid boundary
|
|
// mismatch when callers overwrite c.Request.Header after multipart rebuild.
|
|
var contentType string
|
|
if saved, ok := c.Get("_original_multipart_ct"); ok {
|
|
contentType = saved.(string)
|
|
} else {
|
|
contentType = c.Request.Header.Get("Content-Type")
|
|
c.Set("_original_multipart_ct", contentType)
|
|
}
|
|
boundary, err := parseBoundary(contentType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
|
|
form, err := reader.ReadForm(multipartMemoryLimit())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Reset request body
|
|
if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil {
|
|
return nil, seekErr
|
|
}
|
|
c.Request.Body = io.NopCloser(storage)
|
|
return form, nil
|
|
}
|
|
|
|
func processFormMap(formMap map[string]any, v any) error {
|
|
jsonData, err := Marshal(formMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = Unmarshal(jsonData, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseFormData(data []byte, v any) error {
|
|
values, err := url.ParseQuery(string(data))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
formMap := make(map[string]any)
|
|
for key, vals := range values {
|
|
if len(vals) == 1 {
|
|
formMap[key] = vals[0]
|
|
} else {
|
|
formMap[key] = vals
|
|
}
|
|
}
|
|
|
|
return processFormMap(formMap, v)
|
|
}
|
|
|
|
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
|
contentType := c.Request.Header.Get("Content-Type")
|
|
boundary, err := parseBoundary(contentType)
|
|
if err != nil {
|
|
if errors.Is(err, errBoundaryNotFound) {
|
|
return Unmarshal(data, v) // Fallback to JSON
|
|
}
|
|
return err
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(data), boundary)
|
|
form, err := reader.ReadForm(multipartMemoryLimit())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer form.RemoveAll()
|
|
formMap := make(map[string]any)
|
|
for key, vals := range form.Value {
|
|
if len(vals) == 1 {
|
|
formMap[key] = vals[0]
|
|
} else {
|
|
formMap[key] = vals
|
|
}
|
|
}
|
|
|
|
return processFormMap(formMap, v)
|
|
}
|
|
|
|
var errBoundaryNotFound = errors.New("multipart boundary not found")
|
|
|
|
// parseBoundary extracts the multipart boundary from the Content-Type header using mime.ParseMediaType
|
|
func parseBoundary(contentType string) (string, error) {
|
|
if contentType == "" {
|
|
return "", errBoundaryNotFound
|
|
}
|
|
// Boundary-UUID / boundary-------xxxxxx
|
|
_, params, err := mime.ParseMediaType(contentType)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
boundary, ok := params["boundary"]
|
|
if !ok || boundary == "" {
|
|
return "", errBoundaryNotFound
|
|
}
|
|
return boundary, nil
|
|
}
|
|
|
|
// multipartMemoryLimit returns the configured multipart memory limit in bytes
|
|
func multipartMemoryLimit() int64 {
|
|
limitMB := constant.MaxFileDownloadMB
|
|
if limitMB <= 0 {
|
|
limitMB = 32
|
|
}
|
|
return int64(limitMB) << 20
|
|
}
|