mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-15 23:37:28 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ec4565601 | ||
|
|
a4fb33957f | ||
|
|
909c5eb276 | ||
|
|
8723e3f239 | ||
|
|
9328b907f2 | ||
|
|
8efa12b941 | ||
|
|
7b997b3a2c | ||
|
|
700c05b826 | ||
|
|
c5103237b0 | ||
|
|
f500eb17a8 | ||
|
|
86f6bb7abe |
@@ -1,8 +1,8 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
//"os"
|
||||
//"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -63,8 +63,8 @@ var EmailDomainWhitelist = []string{
|
||||
"foxmail.com",
|
||||
}
|
||||
|
||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
|
||||
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
||||
var DebugEnabled bool
|
||||
var MemoryCacheEnabled bool
|
||||
|
||||
var LogConsumeEnabled = true
|
||||
|
||||
@@ -103,22 +103,22 @@ var RetryTimes = 0
|
||||
|
||||
//var RootUserEmail = ""
|
||||
|
||||
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
var IsMasterNode bool
|
||||
|
||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
var RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
var requestInterval int
|
||||
var RequestInterval time.Duration
|
||||
|
||||
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
|
||||
var SyncFrequency int // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
var GeminiSafetySetting string
|
||||
|
||||
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
|
||||
var CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
|
||||
var CohereSafetySetting string
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
@@ -145,13 +145,13 @@ var (
|
||||
// All durations are in seconds
|
||||
// Shouldn't be larger than RateLimitKeyExpirationDuration
|
||||
var (
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
GlobalApiRateLimitEnable bool
|
||||
GlobalApiRateLimitNum int
|
||||
GlobalApiRateLimitDuration int64
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
GlobalWebRateLimitEnable bool
|
||||
GlobalWebRateLimitNum int
|
||||
GlobalWebRateLimitDuration int64
|
||||
|
||||
UploadRateLimitNum = 10
|
||||
UploadRateLimitDuration int64 = 60
|
||||
@@ -235,6 +235,7 @@ const (
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -288,4 +289,5 @@ var ChannelBaseURLs = []string{
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -66,4 +68,31 @@ func LoadEnv() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize variables from constants.go that were using environment variables
|
||||
DebugEnabled = os.Getenv("DEBUG") == "true"
|
||||
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
||||
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
|
||||
// Parse requestInterval and set RequestInterval
|
||||
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
// Initialize variables with GetEnvOrDefault
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
|
||||
// Initialize string variables with GetEnvOrDefaultString
|
||||
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
|
||||
|
||||
// Initialize rate limit variables
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
}
|
||||
|
||||
@@ -12,3 +12,7 @@ func DecodeJson(data []byte, v any) error {
|
||||
func DecodeJsonStr(data string, v any) error {
|
||||
return DecodeJson(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
@@ -4,32 +4,39 @@ import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
|
||||
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
|
||||
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
|
||||
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
|
||||
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
|
||||
var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
var ForceStreamOption bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
var AzureDefaultAPIVersion string
|
||||
var GeminiVisionMaxImageNum int
|
||||
var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
|
||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
@@ -43,6 +50,3 @@ func InitEnv() {
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
|
||||
@@ -45,15 +45,16 @@ type RealtimeUsage struct {
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
CachedCreationTokens int
|
||||
CachedCreationTokens int `json:"-"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
}
|
||||
|
||||
type OutputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
type RealtimeSession struct {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dify
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -213,12 +212,8 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
||||
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var responseText string
|
||||
usage := &dto.Usage{}
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
var nodeToken int
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var difyResponse DifyChunkChatCompletionResponse
|
||||
err := json.Unmarshal([]byte(data), &difyResponse)
|
||||
@@ -247,13 +242,10 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.SysError("error reading stream: " + err.Error())
|
||||
}
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
|
||||
@@ -56,6 +56,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
if ok {
|
||||
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
|
||||
@@ -65,6 +66,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clean the parameters before appending
|
||||
cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
|
||||
tool.Function.Parameters = cleanedParams
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if codeExecution {
|
||||
@@ -86,11 +90,11 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
} else if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTool{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
//geminiRequest.Tools = []GeminiChatTool{
|
||||
// {
|
||||
// FunctionDeclarations: textRequest.Functions,
|
||||
// },
|
||||
//}
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
@@ -229,6 +233,96 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
|
||||
// cleanFunctionParameters returns a cleaned copy of a JSON-schema style
// function-parameter description with the fields this converter treats as
// unsupported removed from every schema node it can reach.
//
// For each schema node (the root, every value under "properties", every
// "items" schema, and every entry of "allOf"/"anyOf"/"oneOf") it:
//   - deletes "default", "exclusiveMaximum" and "exclusiveMinimum";
//   - deletes "format" when the node's "type" is "string" and the format is
//     neither "enum" nor "date-time".
//
// A nil input yields nil, and any value that is not a map[string]interface{}
// is returned unchanged. Every map that is visited is copied before being
// modified, so the caller's maps are never mutated. Keys of a "properties"
// map are property names rather than schema keywords, so a property that is
// literally named "default" is preserved.
func cleanFunctionParameters(params interface{}) interface{} {
	if params == nil {
		return nil
	}

	schema, ok := params.(map[string]interface{})
	if !ok {
		// Not a schema object (array, primitive, ...): nothing to clean.
		return params
	}

	// Shallow copy so the deletions below never touch the caller's map.
	cleaned := make(map[string]interface{}, len(schema))
	for key, value := range schema {
		cleaned[key] = value
	}

	// Strip the unsupported keywords from this node itself. Doing it here,
	// rather than only while walking a parent's "properties", is what makes
	// the cleaning apply to "items" and allOf/anyOf/oneOf schemas as well.
	delete(cleaned, "default")
	delete(cleaned, "exclusiveMaximum")
	delete(cleaned, "exclusiveMinimum")

	if schemaType, _ := cleaned["type"].(string); schemaType == "string" {
		if format, isString := cleaned["format"].(string); isString && format != "enum" && format != "date-time" {
			delete(cleaned, "format")
		}
	}

	// The keys of "properties" are user-chosen names, so only the values are
	// treated as schema nodes. Non-map values come back unchanged.
	if props, ok := cleaned["properties"].(map[string]interface{}); ok && props != nil {
		cleanedProps := make(map[string]interface{}, len(props))
		for name, prop := range props {
			cleanedProps[name] = cleanFunctionParameters(prop)
		}
		cleaned["properties"] = cleanedProps
	}

	cleanList := func(list []interface{}) []interface{} {
		out := make([]interface{}, len(list))
		for i, item := range list {
			out[i] = cleanFunctionParameters(item)
		}
		return out
	}

	// "items" may be a single schema or a list of schemas.
	switch items := cleaned["items"].(type) {
	case map[string]interface{}:
		if items != nil {
			cleaned["items"] = cleanFunctionParameters(items)
		}
	case []interface{}:
		cleaned["items"] = cleanList(items)
	}

	// Schema composition keywords hold lists of schemas.
	for _, field := range []string{"allOf", "anyOf", "oneOf"} {
		if nested, ok := cleaned[field].([]interface{}); ok {
			cleaned[field] = cleanList(nested)
		}
	}

	return cleaned
}
|
||||
|
||||
|
||||
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
||||
if depth >= 5 {
|
||||
return schema
|
||||
|
||||
105
relay/channel/xai/adaptor.go
Normal file
105
relay/channel/xai/adaptor.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
return nil, errors.New("not available")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//not available
|
||||
return nil, errors.New("not available")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
request.Size = ""
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
request.StreamOptions = nil
|
||||
if strings.HasPrefix(request.Model, "grok-3-mini") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
if strings.HasSuffix(request.Model, "-high") {
|
||||
request.ReasoningEffort = "high"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-high")
|
||||
} else if strings.HasSuffix(request.Model, "-low") {
|
||||
request.ReasoningEffort = "low"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-low")
|
||||
} else if strings.HasSuffix(request.Model, "-medium") {
|
||||
request.ReasoningEffort = "medium"
|
||||
request.Model = strings.TrimSuffix(request.Model, "-medium")
|
||||
}
|
||||
info.ReasoningEffort = request.ReasoningEffort
|
||||
info.UpstreamModelName = request.Model
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
//not available
|
||||
return nil, errors.New("not available")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = xAIStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = xAIHandler(c, resp, info)
|
||||
}
|
||||
//if _, ok := usage.(*dto.Usage); ok && usage != nil {
|
||||
// usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
|
||||
//}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
18
relay/channel/xai/constants.go
Normal file
18
relay/channel/xai/constants.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package xai
|
||||
|
||||
// ModelList enumerates the model names returned by Adaptor.GetModelList.
var ModelList = []string{
	// grok-3 and grok-3-mini
	"grok-3-beta", "grok-3-mini-beta",
	// fast variants of grok-3 and grok-3-mini
	"grok-3-fast-beta", "grok-3-mini-fast-beta",
	// grok-3-mini names carrying a reasoning-effort suffix
	// (-high / -low / -medium); ConvertOpenAIRequest strips the suffix
	// and sets the request's ReasoningEffort from it
	"grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
	"grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
	// image model
	"grok-2-image",
	// legacy models
	"grok-2", "grok-2-vision",
	"grok-beta", "grok-vision-beta",
}

// ChannelName is the name returned by Adaptor.GetChannelName.
var ChannelName = "xai"
|
||||
14
relay/channel/xai/dto.go
Normal file
14
relay/channel/xai/dto.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package xai
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
// ChatCompletionResponse represents the response from XAI chat completion API
|
||||
type ChatCompletionResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []dto.ChatCompletionsStreamResponseChoice
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
}
|
||||
107
relay/channel/xai/text.go
Normal file
107
relay/channel/xai/text.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||
if xAIResp == nil {
|
||||
return nil
|
||||
}
|
||||
if xAIResp.Usage != nil {
|
||||
xAIResp.Usage.CompletionTokens = usage.CompletionTokens
|
||||
}
|
||||
openAIResp := &dto.ChatCompletionsStreamResponse{
|
||||
Id: xAIResp.Id,
|
||||
Object: xAIResp.Object,
|
||||
Created: xAIResp.Created,
|
||||
Model: xAIResp.Model,
|
||||
Choices: xAIResp.Choices,
|
||||
Usage: xAIResp.Usage,
|
||||
}
|
||||
|
||||
return openAIResp
|
||||
}
|
||||
|
||||
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var xAIResp *dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &xAIResp)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
// 把 xAI 的usage转换为 OpenAI 的usage
|
||||
if xAIResp.Usage != nil {
|
||||
usage.PromptTokens = xAIResp.Usage.PromptTokens
|
||||
usage.TotalTokens = xAIResp.Usage.TotalTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
|
||||
err = helper.ObjectData(c, openaiResponse)
|
||||
if err != nil {
|
||||
common.SysError(err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
var response *dto.TextResponse
|
||||
err = common.DecodeJson(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
||||
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||
|
||||
// new body
|
||||
encodeJson, err := common.EncodeJson(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// set new body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
return nil, &response.Usage
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -35,7 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
|
||||
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/embeddings", baseUrl), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/chat/completions", baseUrl), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
@@ -60,8 +67,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
package zhipu_4v
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -119,163 +111,3 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
|
||||
ToolChoice: request.ToolChoice,
|
||||
}
|
||||
}
|
||||
|
||||
//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
|
||||
// fullTextResponse := dto.OpenAITextResponse{
|
||||
// Id: response.Id,
|
||||
// Object: "chat.completion",
|
||||
// Created: common.GetTimestamp(),
|
||||
// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
|
||||
// Usage: response.Usage,
|
||||
// }
|
||||
// for i, choice := range response.TextResponseChoices {
|
||||
// content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
|
||||
// openaiChoice := dto.OpenAITextResponseChoice{
|
||||
// Index: i,
|
||||
// Message: dto.Message{
|
||||
// Role: choice.Role,
|
||||
// Content: content,
|
||||
// },
|
||||
// FinishReason: "",
|
||||
// }
|
||||
// if i == len(response.TextResponseChoices)-1 {
|
||||
// openaiChoice.FinishReason = "stop"
|
||||
// }
|
||||
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
|
||||
// }
|
||||
// return &fullTextResponse
|
||||
//}
|
||||
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
|
||||
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
|
||||
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
|
||||
choice.Index = zhipuResponse.Choices[0].Index
|
||||
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
|
||||
response := dto.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: zhipuResponse.Created,
|
||||
Model: "glm-4v",
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
|
||||
response := streamResponseZhipu2OpenAI(zhipuResponse)
|
||||
return response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage *dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
|
||||
var streamResponse ZhipuV4StreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
}
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
if strings.Contains(data, "prompt_tokens") {
|
||||
response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
|
||||
} else {
|
||||
response = streamResponseZhipu2OpenAI(&streamResponse)
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var textResponse ZhipuV4Response
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the HTTPClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
return nil, &textResponse.Usage
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
@@ -92,6 +93,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeOpenRouter
|
||||
case common.ChannelTypeXinference:
|
||||
apiType = APITypeXinference
|
||||
case common.ChannelTypeXai:
|
||||
apiType = APITypeXai
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -56,6 +56,9 @@ func StringData(c *gin.Context, str string) error {
|
||||
}
|
||||
|
||||
func ObjectData(c *gin.Context, object interface{}) error {
|
||||
if object == nil {
|
||||
return errors.New("object is nil")
|
||||
}
|
||||
jsonData, err := json.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Buffer sizes handed to scanner.Buffer when a stream scanner is set up.
const (
|
||||
// InitialScannerBufferSize is the length of the byte slice initially given to
// the scanner as its working buffer.
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
|
||||
// MaxScannerBufferSize is the value passed as the max argument of
// scanner.Buffer.
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
||||
)
|
||||
|
||||
func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
|
||||
|
||||
if resp == nil {
|
||||
@@ -38,7 +43,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
ticker.Stop()
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
SetEventStreamHeaders(c)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
"one-api/relay/channel/volcengine"
|
||||
"one-api/relay/channel/xai"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
@@ -85,6 +86,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &openai.Adaptor{}
|
||||
case constant.APITypeXinference:
|
||||
return &openai.Adaptor{}
|
||||
case constant.APITypeXai:
|
||||
return &xai.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -199,6 +199,15 @@ var defaultModelRatio = map[string]float64{
|
||||
"llama-3-sonar-small-32k-online": 0.2 / 1000 * USD,
|
||||
"llama-3-sonar-large-32k-chat": 1 / 1000 * USD,
|
||||
"llama-3-sonar-large-32k-online": 1 / 1000 * USD,
|
||||
// grok
|
||||
"grok-3-beta": 1.5,
|
||||
"grok-3-mini-beta": 0.15,
|
||||
"grok-2": 1,
|
||||
"grok-2-vision": 1,
|
||||
"grok-beta": 2.5,
|
||||
"grok-vision-beta": 2.5,
|
||||
"grok-3-fast-beta": 2.5,
|
||||
"grok-3-mini-fast-beta": 0.3,
|
||||
}
|
||||
|
||||
var defaultModelPrice = map[string]float64{
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
"react-turnstile": "^1.0.5",
|
||||
"semantic-ui-offline": "^2.5.0",
|
||||
"semantic-ui-react": "^2.1.3",
|
||||
"sse": "github:mpetazzoni/sse.js",
|
||||
"sse": "https://github.com/mpetazzoni/sse.js",
|
||||
"i18next": "^23.16.8",
|
||||
"react-i18next": "^13.0.0",
|
||||
"i18next-browser-languagedetector": "^7.2.0"
|
||||
|
||||
@@ -115,4 +115,9 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'blue',
|
||||
label: '字节火山方舟、豆包、DeepSeek通用'
|
||||
},
|
||||
{
|
||||
value: 48,
|
||||
color: 'blue',
|
||||
label: 'xAI'
|
||||
}
|
||||
];
|
||||
|
||||
Reference in New Issue
Block a user