mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-05 02:33:34 +00:00
Compare commits
15 Commits
jsrt
...
refactor_e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9e03e6172 | ||
|
|
cb16bf552e | ||
|
|
98952198bb | ||
|
|
338e914a60 | ||
|
|
0e6b608f91 | ||
|
|
f1856fe4d2 | ||
|
|
870cdd5a56 | ||
|
|
f0f277dc2a | ||
|
|
b695e67154 | ||
|
|
fa2cd85007 | ||
|
|
4a8b7bfa37 | ||
|
|
7403df7e9c | ||
|
|
617c8e8f4f | ||
|
|
aa793088ed | ||
|
|
0089157b83 |
11
.env.example
11
.env.example
@@ -73,14 +73,3 @@
|
|||||||
# 节点类型
|
# 节点类型
|
||||||
# 如果是主节点则为master
|
# 如果是主节点则为master
|
||||||
# NODE_TYPE=master
|
# NODE_TYPE=master
|
||||||
|
|
||||||
|
|
||||||
# JavaScript 运行时配置
|
|
||||||
# 是否启用(默认:false)
|
|
||||||
# JS_RUNTIME_ENABLED=true
|
|
||||||
# 最大虚拟机数量(默认:8)
|
|
||||||
# JS_MAX_VM_COUNT=
|
|
||||||
# 运行超时时间(单位:秒,默认:5)
|
|
||||||
# JS_SCRIPT_TIMEOUT=
|
|
||||||
# 脚本文件夹(默认:scripts/)
|
|
||||||
# JS_SCRIPT_PATH=
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
}
|
}
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = UnmarshalJson(requestBody, &v)
|
err = Unmarshal(requestBody, &v)
|
||||||
} else {
|
} else {
|
||||||
// skip for now
|
// skip for now
|
||||||
// TODO: someday non json request have variant model, we will need to implementation this
|
// TODO: someday non json request have variant model, we will need to implementation this
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnmarshalJson(data []byte, v any) error {
|
func Unmarshal(data []byte, v any) error {
|
||||||
return json.Unmarshal(data, v)
|
return json.Unmarshal(data, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -17,6 +17,6 @@ func DecodeJson(reader *bytes.Reader, v any) error {
|
|||||||
return json.NewDecoder(reader).Decode(v)
|
return json.NewDecoder(reader).Decode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncodeJson(v any) ([]byte, error) {
|
func Marshal(v any) ([]byte, error) {
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,16 +32,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
|||||||
return string(bytes)
|
return string(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StrToMap(str string) map[string]interface{} {
|
func StrToMap(str string) (map[string]interface{}, error) {
|
||||||
m := make(map[string]interface{})
|
m := make(map[string]interface{})
|
||||||
err := json.Unmarshal([]byte(str), &m)
|
err := Unmarshal([]byte(str), &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return m
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsJsonStr(str string) bool {
|
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||||
|
var js []interface{}
|
||||||
|
err := json.Unmarshal([]byte(str), &js)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return js, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsJsonArray(str string) bool {
|
||||||
|
var js []interface{}
|
||||||
|
return json.Unmarshal([]byte(str), &js) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsJsonObject(str string) bool {
|
||||||
var js map[string]interface{}
|
var js map[string]interface{}
|
||||||
return json.Unmarshal([]byte(str), &js) == nil
|
return json.Unmarshal([]byte(str), &js) == nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,149 +0,0 @@
|
|||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StructToMap 递归地把任意结构体 v 转成 map[string]any。
|
|
||||||
// - 只处理导出字段;未导出字段会被跳过。
|
|
||||||
// - 优先使用 `json:"name"` 里逗号前的部分作为键;如果是 "-" 则忽略该字段;若无 tag,则使用字段名。
|
|
||||||
// - 对指针、切片、数组、嵌套结构体、map 做深度遍历,保持原始结构。
|
|
||||||
func StructToMap(v any) (map[string]any, error) {
|
|
||||||
val := reflect.ValueOf(v)
|
|
||||||
if !val.IsValid() {
|
|
||||||
return nil, fmt.Errorf("nil value")
|
|
||||||
}
|
|
||||||
for val.Kind() == reflect.Pointer {
|
|
||||||
if val.IsNil() {
|
|
||||||
return nil, fmt.Errorf("nil pointer")
|
|
||||||
}
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
if val.Kind() != reflect.Struct {
|
|
||||||
return nil, fmt.Errorf("expect struct, got %s", val.Kind())
|
|
||||||
}
|
|
||||||
|
|
||||||
return structValueToMap(val), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func structValueToMap(val reflect.Value) map[string]any {
|
|
||||||
out := make(map[string]any, val.NumField())
|
|
||||||
|
|
||||||
typ := val.Type()
|
|
||||||
for i := 0; i < val.NumField(); i++ {
|
|
||||||
f := typ.Field(i)
|
|
||||||
if f.PkgPath != "" { // 未导出字段
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 json tag
|
|
||||||
tag := f.Tag.Get("json")
|
|
||||||
name, opts := parseTag(tag)
|
|
||||||
if name == "-" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
name = f.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
fv := val.Field(i)
|
|
||||||
out[name] = valueToAny(fv, opts.Contains("omitempty"))
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// valueToAny 递归处理各种值类型。
|
|
||||||
func valueToAny(v reflect.Value, omitEmpty bool) any {
|
|
||||||
if !v.IsValid() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for v.Kind() == reflect.Pointer {
|
|
||||||
if v.IsNil() {
|
|
||||||
if omitEmpty {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// 保持与 encoding/json 行为一致,nil 指针写成 null
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v.Kind() {
|
|
||||||
|
|
||||||
case reflect.Struct:
|
|
||||||
return structValueToMap(v)
|
|
||||||
|
|
||||||
case reflect.Slice, reflect.Array:
|
|
||||||
l := v.Len()
|
|
||||||
arr := make([]any, l)
|
|
||||||
for i := 0; i < l; i++ {
|
|
||||||
arr[i] = valueToAny(v.Index(i), false)
|
|
||||||
}
|
|
||||||
return arr
|
|
||||||
|
|
||||||
case reflect.Map:
|
|
||||||
m := make(map[string]any, v.Len())
|
|
||||||
iter := v.MapRange()
|
|
||||||
for iter.Next() {
|
|
||||||
k := iter.Key()
|
|
||||||
// 只支持 string key,与 encoding/json 一致
|
|
||||||
if k.Kind() == reflect.String {
|
|
||||||
m[k.String()] = valueToAny(iter.Value(), false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
|
|
||||||
default:
|
|
||||||
// 基本类型直接返回其接口值
|
|
||||||
return v.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tagOptions 用于判断是否包含 "omitempty"
|
|
||||||
type tagOptions string
|
|
||||||
|
|
||||||
func (o tagOptions) Contains(opt string) bool {
|
|
||||||
if len(o) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, s := range splitComma(string(o)) {
|
|
||||||
if s == opt {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTag(tag string) (string, tagOptions) {
|
|
||||||
if idx := indexComma(tag); idx != -1 {
|
|
||||||
return tag[:idx], tagOptions(tag[idx+1:])
|
|
||||||
}
|
|
||||||
return tag, tagOptions("")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 避免 strings.Split 额外分配
|
|
||||||
func indexComma(s string) int {
|
|
||||||
for i, r := range s {
|
|
||||||
if r == ',' {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitComma(s string) []string {
|
|
||||||
var parts []string
|
|
||||||
start := 0
|
|
||||||
for i, r := range s {
|
|
||||||
if r == ',' {
|
|
||||||
parts = append(parts, s[start:i])
|
|
||||||
start = i + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if start <= len(s) {
|
|
||||||
parts = append(parts, s[start:])
|
|
||||||
}
|
|
||||||
return parts
|
|
||||||
}
|
|
||||||
@@ -17,11 +17,18 @@ const (
|
|||||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
|
|
||||||
/* channel related keys */
|
/* channel related keys */
|
||||||
ContextKeyBaseUrl ContextKey = "base_url"
|
ContextKeyChannelId ContextKey = "channel_id"
|
||||||
ContextKeyChannelType ContextKey = "channel_type"
|
ContextKeyChannelName ContextKey = "channel_name"
|
||||||
ContextKeyChannelId ContextKey = "channel_id"
|
ContextKeyChannelCreateTime ContextKey = "channel_create_name"
|
||||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||||
ContextKeyParamOverride ContextKey = "param_override"
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
|
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||||
|
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||||
|
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||||
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
|
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||||
|
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||||
|
|
||||||
/* user related keys */
|
/* user related keys */
|
||||||
ContextKeyUserId ContextKey = "id"
|
ContextKeyUserId ContextKey = "id"
|
||||||
|
|||||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type MultiKeyMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||||
|
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||||
|
)
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -29,7 +30,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == constant.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return errors.New("midjourney channel test is not supported"), nil
|
||||||
@@ -98,14 +99,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info, nil)
|
err = helper.ModelMappedHelper(c, info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
testModel = info.UpstreamModelName
|
testModel = info.UpstreamModelName
|
||||||
|
|
||||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
request := buildTestRequest(testModel)
|
||||||
@@ -116,45 +117,45 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor.Init(info)
|
adaptor.Init(info)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed)
|
||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
c.Request.Body = io.NopCloser(requestBody)
|
c.Request.Body = io.NopCloser(requestBody)
|
||||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
err := service.RelayErrorHandler(httpResp, true)
|
err := service.RelayErrorHandler(httpResp, true)
|
||||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
return err, types.NewError(err, types.ErrorCodeBadResponse)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
return respErr, respErr
|
||||||
}
|
}
|
||||||
if usageA == nil {
|
if usageA == nil {
|
||||||
return errors.New("usage is nil"), nil
|
return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
usage := usageA.(*dto.Usage)
|
usage := usageA.(*dto.Usage)
|
||||||
result := w.Result()
|
result := w.Result()
|
||||||
respBody, err := io.ReadAll(result.Body)
|
respBody, err := io.ReadAll(result.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
info.PromptTokens = usage.PromptTokens
|
info.PromptTokens = usage.PromptTokens
|
||||||
|
|
||||||
@@ -246,15 +247,15 @@ func TestChannel(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, _ = testChannel(channel, testModel)
|
_, newAPIError := testChannel(channel, testModel)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": newAPIError.Error(),
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@@ -298,17 +299,15 @@ func testAllChannels(notify bool) error {
|
|||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, openaiWithStatusErr := testChannel(channel, "")
|
err, newAPIError := testChannel(channel, "")
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
|
||||||
shouldBanChannel := false
|
shouldBanChannel := false
|
||||||
|
|
||||||
// request error disables the channel
|
// request error disables the channel
|
||||||
if openaiWithStatusErr != nil {
|
if err != nil {
|
||||||
oaiErr := openaiWithStatusErr.Error
|
shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError)
|
||||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
|
||||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
@@ -322,7 +321,7 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// enable channel
|
// enable channel
|
||||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) {
|
||||||
service.EnableChannel(channel.Id, channel.Name)
|
service.EnableChannel(channel.Id, channel.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,6 +353,10 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticallyTestChannels(frequency int) {
|
func AutomaticallyTestChannels(frequency int) {
|
||||||
|
if frequency <= 0 {
|
||||||
|
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||||
|
return
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
common.SysLog("testing all channels")
|
common.SysLog("testing all channels")
|
||||||
|
|||||||
@@ -380,9 +380,47 @@ func GetChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AddChannelRequest struct {
|
||||||
|
Mode string `json:"mode"`
|
||||||
|
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||||
|
Channel *model.Channel `json:"channel"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||||
|
if keys == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var keyArray []interface{}
|
||||||
|
err := common.Unmarshal([]byte(keys), &keyArray)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||||||
|
}
|
||||||
|
cleanKeys := make([]string, 0, len(keyArray))
|
||||||
|
for _, key := range keyArray {
|
||||||
|
var keyStr string
|
||||||
|
switch v := key.(type) {
|
||||||
|
case string:
|
||||||
|
keyStr = strings.TrimSpace(v)
|
||||||
|
default:
|
||||||
|
bytes, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
|
||||||
|
}
|
||||||
|
keyStr = string(bytes)
|
||||||
|
}
|
||||||
|
if keyStr != "" {
|
||||||
|
cleanKeys = append(cleanKeys, keyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(cleanKeys) == 0 {
|
||||||
|
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||||||
|
}
|
||||||
|
return cleanKeys, nil
|
||||||
|
}
|
||||||
|
|
||||||
func AddChannel(c *gin.Context) {
|
func AddChannel(c *gin.Context) {
|
||||||
channel := model.Channel{}
|
addChannelRequest := AddChannelRequest{}
|
||||||
err := c.ShouldBindJSON(&channel)
|
err := c.ShouldBindJSON(&addChannelRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -390,7 +428,8 @@ func AddChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = channel.ValidateSettings()
|
|
||||||
|
err = addChannelRequest.Channel.ValidateSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -398,49 +437,111 @@ func AddChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel.CreatedTime = common.GetTimestamp()
|
|
||||||
keys := strings.Split(channel.Key, "\n")
|
if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
|
||||||
if channel.Type == constant.ChannelTypeVertexAi {
|
c.JSON(http.StatusOK, gin.H{
|
||||||
if channel.Other == "" {
|
"success": false,
|
||||||
|
"message": "channel cannot be empty",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the length of the model name
|
||||||
|
for _, m := range addChannelRequest.Channel.GetModels() {
|
||||||
|
if len(m) > 255 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": fmt.Sprintf("模型名称过长: %s", m),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||||
|
if addChannelRequest.Channel.Other == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "部署地区不能为空",
|
"message": "部署地区不能为空",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if common.IsJsonStr(channel.Other) {
|
regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
|
||||||
// must have default
|
if err != nil {
|
||||||
regionMap := common.StrToMap(channel.Other)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
if regionMap["default"] == nil {
|
"success": false,
|
||||||
c.JSON(http.StatusOK, gin.H{
|
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||||
"success": false,
|
})
|
||||||
"message": "部署地区必须包含default字段",
|
return
|
||||||
})
|
}
|
||||||
return
|
if regionMap["default"] == nil {
|
||||||
}
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "部署地区必须包含default字段",
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
keys = []string{channel.Key}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||||||
|
keys := make([]string, 0)
|
||||||
|
switch addChannelRequest.Mode {
|
||||||
|
case "multi_to_single":
|
||||||
|
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||||
|
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||||
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||||
|
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||||
|
} else {
|
||||||
|
cleanKeys := make([]string, 0)
|
||||||
|
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
cleanKeys = append(cleanKeys, key)
|
||||||
|
}
|
||||||
|
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||||
|
}
|
||||||
|
keys = []string{addChannelRequest.Channel.Key}
|
||||||
|
case "batch":
|
||||||
|
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||||
|
// multi json
|
||||||
|
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||||||
|
}
|
||||||
|
case "single":
|
||||||
|
keys = []string{addChannelRequest.Channel.Key}
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "不支持的添加模式",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
channels := make([]model.Channel, 0, len(keys))
|
channels := make([]model.Channel, 0, len(keys))
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
localChannel := channel
|
localChannel := addChannelRequest.Channel
|
||||||
localChannel.Key = key
|
localChannel.Key = key
|
||||||
// Validate the length of the model name
|
channels = append(channels, *localChannel)
|
||||||
models := strings.Split(localChannel.Models, ",")
|
|
||||||
for _, model := range models {
|
|
||||||
if len(model) > 255 {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": false,
|
|
||||||
"message": fmt.Sprintf("模型名称过长: %s", model),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
channels = append(channels, localChannel)
|
|
||||||
}
|
}
|
||||||
err = model.BatchInsertChannels(channels)
|
err = model.BatchInsertChannels(channels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -641,16 +742,20 @@ func UpdateChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if common.IsJsonStr(channel.Other) {
|
regionMap, err := common.StrToMap(channel.Other)
|
||||||
// must have default
|
if err != nil {
|
||||||
regionMap := common.StrToMap(channel.Other)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
if regionMap["default"] == nil {
|
"success": false,
|
||||||
c.JSON(http.StatusOK, gin.H{
|
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||||
"success": false,
|
})
|
||||||
"message": "部署地区必须包含default字段",
|
return
|
||||||
})
|
}
|
||||||
return
|
if regionMap["default"] == nil {
|
||||||
}
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "部署地区必须包含default字段",
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"slices"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/middleware/jsrt"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/console_setting"
|
"one-api/setting/console_setting"
|
||||||
@@ -35,6 +33,7 @@ func TestStatus(c *gin.Context) {
|
|||||||
"message": "Server is running",
|
"message": "Server is running",
|
||||||
"http_stats": httpStats,
|
"http_stats": httpStats,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetStatus(c *gin.Context) {
|
func GetStatus(c *gin.Context) {
|
||||||
@@ -107,6 +106,7 @@ func GetStatus(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": data,
|
"data": data,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNotice(c *gin.Context) {
|
func GetNotice(c *gin.Context) {
|
||||||
@@ -117,6 +117,7 @@ func GetNotice(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["Notice"],
|
"data": common.OptionMap["Notice"],
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAbout(c *gin.Context) {
|
func GetAbout(c *gin.Context) {
|
||||||
@@ -127,6 +128,7 @@ func GetAbout(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["About"],
|
"data": common.OptionMap["About"],
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetMidjourney(c *gin.Context) {
|
func GetMidjourney(c *gin.Context) {
|
||||||
@@ -137,6 +139,7 @@ func GetMidjourney(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["Midjourney"],
|
"data": common.OptionMap["Midjourney"],
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetHomePageContent(c *gin.Context) {
|
func GetHomePageContent(c *gin.Context) {
|
||||||
@@ -147,6 +150,7 @@ func GetHomePageContent(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": common.OptionMap["HomePageContent"],
|
"data": common.OptionMap["HomePageContent"],
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendEmailVerification(c *gin.Context) {
|
func SendEmailVerification(c *gin.Context) {
|
||||||
@@ -169,7 +173,13 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
localPart := parts[0]
|
localPart := parts[0]
|
||||||
domainPart := parts[1]
|
domainPart := parts[1]
|
||||||
if common.EmailDomainRestrictionEnabled {
|
if common.EmailDomainRestrictionEnabled {
|
||||||
allowed := slices.Contains(common.EmailDomainWhitelist, domainPart)
|
allowed := false
|
||||||
|
for _, domain := range common.EmailDomainWhitelist {
|
||||||
|
if domainPart == domain {
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if !allowed {
|
if !allowed {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -214,6 +224,7 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendPasswordResetEmail(c *gin.Context) {
|
func SendPasswordResetEmail(c *gin.Context) {
|
||||||
@@ -252,6 +263,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type PasswordResetRequest struct {
|
type PasswordResetRequest struct {
|
||||||
@@ -291,13 +303,5 @@ func ResetPassword(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": password,
|
"data": password,
|
||||||
})
|
})
|
||||||
}
|
return
|
||||||
|
|
||||||
func ReloadJSScripts(c *gin.Context) {
|
|
||||||
jsrt.ReloadJSScripts()
|
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "JavaScript 脚本已重新加载",
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,45 +3,44 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Playground(c *gin.Context) {
|
func Playground(c *gin.Context) {
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var newAPIError *types.NewAPIError
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
c.JSON(openaiErr.StatusCode, gin.H{
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
"error": openaiErr.Error,
|
"error": newAPIError.ToOpenAIError(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
useAccessToken := c.GetBool("use_access_token")
|
useAccessToken := c.GetBool("use_access_token")
|
||||||
if useAccessToken {
|
if useAccessToken {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
playgroundRequest := &dto.PlayGroundRequest{}
|
playgroundRequest := &dto.PlayGroundRequest{}
|
||||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if playgroundRequest.Model == "" {
|
if playgroundRequest.Model == "" {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("original_model", playgroundRequest.Model)
|
c.Set("original_model", playgroundRequest.Model)
|
||||||
@@ -52,26 +51,32 @@ func Playground(c *gin.Context) {
|
|||||||
group = userGroup
|
group = userGroup
|
||||||
} else {
|
} else {
|
||||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
}
|
}
|
||||||
c.Set("token_name", "playground-"+group)
|
|
||||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
userId := c.GetInt("id")
|
||||||
|
//c.Set("token_name", "playground-"+group)
|
||||||
|
tempToken := &model.Token{
|
||||||
|
UserId: userId,
|
||||||
|
Name: fmt.Sprintf("playground-%s", group),
|
||||||
|
Group: group,
|
||||||
|
}
|
||||||
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
|
_, err = getChannel(c, group, playgroundRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
|
|
||||||
// Write user context to ensure acceptUnsetRatio is available
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
userId := c.GetInt("id")
|
|
||||||
userCache, err := model.GetUserCache(userId)
|
userCache, err := model.GetUserCache(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
|
|||||||
@@ -17,14 +17,15 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||||
var err *dto.OpenAIErrorWithStatusCode
|
var err *types.NewAPIError
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err = relay.ImageHelper(c)
|
err = relay.ImageHelper(c)
|
||||||
@@ -55,14 +56,14 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
userGroup := c.GetString("group")
|
userGroup := c.GetString("group")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
other := make(map[string]interface{})
|
other := make(map[string]interface{})
|
||||||
other["error_type"] = err.Error.Type
|
other["error_type"] = err.ErrorType
|
||||||
other["error_code"] = err.Error.Code
|
other["error_code"] = err.GetErrorCode()
|
||||||
other["status_code"] = err.StatusCode
|
other["status_code"] = err.StatusCode
|
||||||
other["channel_id"] = channelId
|
other["channel_id"] = channelId
|
||||||
other["channel_name"] = c.GetString("channel_name")
|
other["channel_name"] = c.GetString("channel_name")
|
||||||
other["channel_type"] = c.GetInt("channel_type")
|
other["channel_type"] = c.GetInt("channel_type")
|
||||||
|
|
||||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -73,25 +74,25 @@ func Relay(c *gin.Context) {
|
|||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var newAPIError *types.NewAPIError
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiErr = relayRequest(c, relayMode, channel)
|
newAPIError = relayRequest(c, relayMode, channel)
|
||||||
|
|
||||||
if openaiErr == nil {
|
if newAPIError == nil {
|
||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,14 +102,14 @@ func Relay(c *gin.Context) {
|
|||||||
common.LogInfo(c, retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
|
common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||||
}
|
}
|
||||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
c.JSON(openaiErr.StatusCode, gin.H{
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
"error": openaiErr.Error,
|
"error": newAPIError.ToOpenAIError(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -127,8 +128,7 @@ func WssRelay(c *gin.Context) {
|
|||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||||
helper.WssError(c, ws, openaiErr.Error)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,25 +137,25 @@ func WssRelay(c *gin.Context) {
|
|||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var newAPIError *types.NewAPIError
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||||
|
|
||||||
if openaiErr == nil {
|
if newAPIError == nil {
|
||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,12 +165,12 @@ func WssRelay(c *gin.Context) {
|
|||||||
common.LogInfo(c, retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||||
}
|
}
|
||||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
helper.WssError(c, ws, openaiErr.Error)
|
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,27 +179,25 @@ func RelayClaude(c *gin.Context) {
|
|||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
var newAPIError *types.NewAPIError
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeErr = claudeRequest(c, channel)
|
newAPIError = claudeRequest(c, channel)
|
||||||
|
|
||||||
if claudeErr == nil {
|
if newAPIError == nil {
|
||||||
return // 成功处理请求,直接返回
|
return // 成功处理请求,直接返回
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
|
|
||||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,30 +207,30 @@ func RelayClaude(c *gin.Context) {
|
|||||||
common.LogInfo(c, retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if claudeErr != nil {
|
if newAPIError != nil {
|
||||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
c.JSON(claudeErr.StatusCode, gin.H{
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": claudeErr.Error,
|
"error": newAPIError.ToClaudeError(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||||
addUsedChannel(c, channel.Id)
|
addUsedChannel(c, channel.Id)
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
return relayHandler(c, relayMode)
|
return relayHandler(c, relayMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||||
addUsedChannel(c, channel.Id)
|
addUsedChannel(c, channel.Id)
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
return relay.WssHelper(c, ws)
|
return relay.WssHelper(c, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||||
addUsedChannel(c, channel.Id)
|
addUsedChannel(c, channel.Id)
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
@@ -259,19 +257,25 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
AutoBan: &autoBanInt,
|
AutoBan: &autoBanInt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
if group == "auto" {
|
||||||
|
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
|
||||||
|
}
|
||||||
|
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
return channel, nil
|
return channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||||
if openaiErr == nil {
|
if openaiErr == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.LocalError {
|
if types.IsChannelError(openaiErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if types.IsLocalError(openaiErr) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if retryTimes <= 0 {
|
if retryTimes <= 0 {
|
||||||
@@ -310,12 +314,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) {
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error()))
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
service.DisableChannel(channelId, channelName, err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,9 +392,10 @@ func RelayTask(c *gin.Context) {
|
|||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
@@ -398,7 +403,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, err := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
return fmt.Errorf("Get Task status code: %d", resp.StatusCode)
|
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- ./data:/data
|
- ./data:/data
|
||||||
- ./logs:/app/logs
|
- ./logs:/app/logs
|
||||||
- ${JS_SCRIPT_DIR:-./scripts}:/app/scripts
|
|
||||||
environment:
|
environment:
|
||||||
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
|
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
|
||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
@@ -22,6 +21,7 @@ services:
|
|||||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||||
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
|
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
|
||||||
|
|
||||||
depends_on:
|
depends_on:
|
||||||
- redis
|
- redis
|
||||||
- mysql
|
- mysql
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package dto
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClaudeMetadata struct {
|
type ClaudeMetadata struct {
|
||||||
@@ -228,7 +229,7 @@ type ClaudeResponse struct {
|
|||||||
Completion string `json:"completion,omitempty"`
|
Completion string `json:"completion,omitempty"`
|
||||||
StopReason string `json:"stop_reason,omitempty"`
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Error *ClaudeError `json:"error,omitempty"`
|
Error *types.ClaudeError `json:"error,omitempty"`
|
||||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||||
Index *int `json:"index,omitempty"`
|
Index *int `json:"index,omitempty"`
|
||||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||||
|
|||||||
12
dto/error.go
12
dto/error.go
@@ -1,5 +1,7 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import "one-api/types"
|
||||||
|
|
||||||
type OpenAIError struct {
|
type OpenAIError struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GeneralErrorResponse struct {
|
type GeneralErrorResponse struct {
|
||||||
Error OpenAIError `json:"error"`
|
Error types.OpenAIError `json:"error"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Msg string `json:"msg"`
|
Msg string `json:"msg"`
|
||||||
Err string `json:"err"`
|
Err string `json:"err"`
|
||||||
ErrorMsg string `json:"error_msg"`
|
ErrorMsg string `json:"error_msg"`
|
||||||
Header struct {
|
Header struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
} `json:"header"`
|
} `json:"header"`
|
||||||
|
|||||||
@@ -65,8 +65,8 @@ type GeneralOpenAIRequest struct {
|
|||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||||
result := make(map[string]any)
|
result := make(map[string]any)
|
||||||
data, _ := common.EncodeJson(r)
|
data, _ := common.Marshal(r)
|
||||||
_ = common.UnmarshalJson(data, &result)
|
_ = common.Unmarshal(data, &result)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
type SimpleResponse struct {
|
type SimpleResponse struct {
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
@@ -28,7 +31,7 @@ type OpenAITextResponse struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created any `json:"created"`
|
Created any `json:"created"`
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error *types.OpenAIError `json:"error,omitempty"`
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +204,7 @@ type OpenAIResponsesResponse struct {
|
|||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
CreatedAt int `json:"created_at"`
|
CreatedAt int `json:"created_at"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error *types.OpenAIError `json:"error,omitempty"`
|
||||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||||
Instructions string `json:"instructions"`
|
Instructions string `json:"instructions"`
|
||||||
MaxOutputTokens int `json:"max_output_tokens"`
|
MaxOutputTokens int `json:"max_output_tokens"`
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import "one-api/types"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RealtimeEventTypeError = "error"
|
RealtimeEventTypeError = "error"
|
||||||
RealtimeEventTypeSessionUpdate = "session.update"
|
RealtimeEventTypeSessionUpdate = "session.update"
|
||||||
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
|
|||||||
EventId string `json:"event_id"`
|
EventId string `json:"event_id"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
//PreviousItemId string `json:"previous_item_id"`
|
//PreviousItemId string `json:"previous_item_id"`
|
||||||
Session *RealtimeSession `json:"session,omitempty"`
|
Session *RealtimeSession `json:"session,omitempty"`
|
||||||
Item *RealtimeItem `json:"item,omitempty"`
|
Item *RealtimeItem `json:"item,omitempty"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error *types.OpenAIError `json:"error,omitempty"`
|
||||||
Response *RealtimeResponse `json:"response,omitempty"`
|
Response *RealtimeResponse `json:"response,omitempty"`
|
||||||
Delta string `json:"delta,omitempty"`
|
Delta string `json:"delta,omitempty"`
|
||||||
Audio string `json:"audio,omitempty"`
|
Audio string `json:"audio,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RealtimeResponse struct {
|
type RealtimeResponse struct {
|
||||||
|
|||||||
5
go.mod
5
go.mod
@@ -11,7 +11,6 @@ require (
|
|||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994
|
|
||||||
github.com/gin-contrib/cors v1.7.2
|
github.com/gin-contrib/cors v1.7.2
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/gin-contrib/gzip v0.0.6
|
||||||
github.com/gin-contrib/sessions v0.0.5
|
github.com/gin-contrib/sessions v0.0.5
|
||||||
@@ -32,7 +31,6 @@ require (
|
|||||||
golang.org/x/crypto v0.35.0
|
golang.org/x/crypto v0.35.0
|
||||||
golang.org/x/image v0.23.0
|
golang.org/x/image v0.23.0
|
||||||
golang.org/x/net v0.35.0
|
golang.org/x/net v0.35.0
|
||||||
golang.org/x/sync v0.11.0
|
|
||||||
gorm.io/driver/mysql v1.4.3
|
gorm.io/driver/mysql v1.4.3
|
||||||
gorm.io/driver/postgres v1.5.2
|
gorm.io/driver/postgres v1.5.2
|
||||||
gorm.io/gorm v1.25.2
|
gorm.io/gorm v1.25.2
|
||||||
@@ -58,11 +56,9 @@ require (
|
|||||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect
|
|
||||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/google/go-cmp v0.6.0 // indirect
|
github.com/google/go-cmp v0.6.0 // indirect
|
||||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect
|
|
||||||
github.com/gorilla/context v1.1.1 // indirect
|
github.com/gorilla/context v1.1.1 // indirect
|
||||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||||
github.com/gorilla/sessions v1.2.1 // indirect
|
github.com/gorilla/sessions v1.2.1 // indirect
|
||||||
@@ -88,6 +84,7 @@ require (
|
|||||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||||
golang.org/x/arch v0.12.0 // indirect
|
golang.org/x/arch v0.12.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||||
|
golang.org/x/sync v0.11.0 // indirect
|
||||||
golang.org/x/sys v0.30.0 // indirect
|
golang.org/x/sys v0.30.0 // indirect
|
||||||
golang.org/x/text v0.22.0 // indirect
|
golang.org/x/text v0.22.0 // indirect
|
||||||
google.golang.org/protobuf v1.34.2 // indirect
|
google.golang.org/protobuf v1.34.2 // indirect
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -1,7 +1,5 @@
|
|||||||
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
|
||||||
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||||
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
|
||||||
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
|
||||||
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
|
||||||
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||||
@@ -42,8 +40,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
|||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994 h1:aQYWswi+hRL2zJqGacdCZx32XjKYV8ApXFGntw79XAM=
|
|
||||||
github.com/dop251/goja v0.0.0-20250630131328-58d95d85e994/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
|
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||||
@@ -87,8 +83,6 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
|
|||||||
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU=
|
|
||||||
github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
|
|
||||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
@@ -103,8 +97,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U=
|
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||||
github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg=
|
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
|
||||||
|
|||||||
4
main.go
4
main.go
@@ -168,11 +168,11 @@ func InitResources() error {
|
|||||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common.SetupLogger()
|
||||||
|
|
||||||
// 加载环境变量
|
// 加载环境变量
|
||||||
common.InitEnv()
|
common.InitEnv()
|
||||||
|
|
||||||
common.SetupLogger()
|
|
||||||
|
|
||||||
// Initialize model settings
|
// Initialize model settings
|
||||||
ratio_setting.InitRatioSettings()
|
ratio_setting.InitRatioSettings()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
|
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
c.Set("id", token.UserId)
|
err = SetupContextForToken(c, token, parts...)
|
||||||
c.Set("token_id", token.Id)
|
if err != nil {
|
||||||
c.Set("token_key", token.Key)
|
return
|
||||||
c.Set("token_name", token.Name)
|
|
||||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
|
||||||
if !token.UnlimitedQuota {
|
|
||||||
c.Set("token_quota", token.RemainQuota)
|
|
||||||
}
|
|
||||||
if token.ModelLimitsEnabled {
|
|
||||||
c.Set("token_model_limit_enabled", true)
|
|
||||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
|
||||||
} else {
|
|
||||||
c.Set("token_model_limit_enabled", false)
|
|
||||||
}
|
|
||||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
|
||||||
c.Set("token_group", token.Group)
|
|
||||||
if len(parts) > 1 {
|
|
||||||
if model.IsAdmin(token.UserId) {
|
|
||||||
c.Set("specific_channel_id", parts[1])
|
|
||||||
} else {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||||
|
if token == nil {
|
||||||
|
return fmt.Errorf("token is nil")
|
||||||
|
}
|
||||||
|
c.Set("id", token.UserId)
|
||||||
|
c.Set("token_id", token.Id)
|
||||||
|
c.Set("token_key", token.Key)
|
||||||
|
c.Set("token_name", token.Name)
|
||||||
|
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||||
|
if !token.UnlimitedQuota {
|
||||||
|
c.Set("token_quota", token.RemainQuota)
|
||||||
|
}
|
||||||
|
if token.ModelLimitsEnabled {
|
||||||
|
c.Set("token_model_limit_enabled", true)
|
||||||
|
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||||
|
} else {
|
||||||
|
c.Set("token_model_limit_enabled", false)
|
||||||
|
}
|
||||||
|
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||||
|
c.Set("token_group", token.Group)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
if model.IsAdmin(token.UserId) {
|
||||||
|
c.Set("specific_channel_id", parts[1])
|
||||||
|
} else {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
|
return fmt.Errorf("普通用户不支持指定渠道")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
|
|
||||||
type ModelRequest struct {
|
type ModelRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
Group string `json:"group,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
@@ -237,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||||
|
// playground chat completions
|
||||||
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
|
||||||
|
}
|
||||||
return &modelRequest, shouldSelectChannel, nil
|
return &modelRequest, shouldSelectChannel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,20 +254,25 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
if channel == nil {
|
if channel == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("channel_id", channel.Id)
|
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||||
c.Set("channel_create_time", channel.CreatedTime)
|
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||||
c.Set("param_override", channel.GetParamOverride())
|
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||||
|
|
||||||
}
|
}
|
||||||
c.Set("auto_ban", channel.GetAutoBan())
|
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
|
||||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
common.SetContextKey(c, constant.ContextKeyBaseUrl, channel.GetBaseURL())
|
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||||
|
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case constant.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
package jsrt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Runtime 配置
|
|
||||||
type JSRuntimeConfig struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
MaxVMCount int `json:"max_vm_count"`
|
|
||||||
ScriptTimeout time.Duration `json:"script_timeout"`
|
|
||||||
ScriptDir string `json:"script_dir"`
|
|
||||||
FetchTimeout time.Duration `json:"fetch_timeout"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
jsConfig = JSRuntimeConfig{}
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultScriptDir = "scripts/"
|
|
||||||
defaultScriptTimeout = 5 * time.Second
|
|
||||||
defaultFetchTimeout = 10 * time.Second
|
|
||||||
defaultMaxVMCount = 8
|
|
||||||
)
|
|
||||||
|
|
||||||
func loadCfg() {
|
|
||||||
if enabled := os.Getenv("JS_RUNTIME_ENABLED"); enabled != "" {
|
|
||||||
jsConfig.Enabled = enabled == "true"
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxCount := os.Getenv("JS_MAX_VM_COUNT"); maxCount != "" {
|
|
||||||
if count, err := strconv.Atoi(maxCount); err == nil && count > 0 {
|
|
||||||
jsConfig.MaxVMCount = count
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
jsConfig.MaxVMCount = defaultMaxVMCount
|
|
||||||
}
|
|
||||||
|
|
||||||
if timeout := os.Getenv("JS_SCRIPT_TIMEOUT"); timeout != "" {
|
|
||||||
if t, err := time.ParseDuration(timeout + "s"); err == nil && t > 0 {
|
|
||||||
jsConfig.ScriptTimeout = t
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
jsConfig.ScriptTimeout = defaultScriptTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
if fetchTimeout := os.Getenv("JS_FETCH_TIMEOUT"); fetchTimeout != "" {
|
|
||||||
if t, err := time.ParseDuration(fetchTimeout + "s"); err == nil && t > 0 {
|
|
||||||
jsConfig.FetchTimeout = t
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
jsConfig.FetchTimeout = defaultFetchTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
jsConfig.ScriptDir = os.Getenv("JS_SCRIPT_DIR")
|
|
||||||
if jsConfig.ScriptDir == "" {
|
|
||||||
jsConfig.ScriptDir = defaultScriptDir
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
package jsrt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func dbQuery(db *gorm.DB, sql string, args ...any) []map[string]any {
|
|
||||||
if db == nil {
|
|
||||||
common.SysError("JS DB is nil")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Raw(sql, args...).Rows()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("JS DB Query Error: " + err.Error())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
columns, err := rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("JS DB Columns Error: " + err.Error())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
results := make([]map[string]any, 0, 100)
|
|
||||||
for rows.Next() {
|
|
||||||
values := make([]any, len(columns))
|
|
||||||
valuePtrs := make([]any, len(columns))
|
|
||||||
for i := range values {
|
|
||||||
valuePtrs[i] = &values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Scan(valuePtrs...); err != nil {
|
|
||||||
common.SysError("JS DB Scan Error: " + err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
row := make(map[string]any, len(columns))
|
|
||||||
for i, col := range columns {
|
|
||||||
val := values[i]
|
|
||||||
if b, ok := val.([]byte); ok {
|
|
||||||
row[col] = string(b)
|
|
||||||
} else {
|
|
||||||
row[col] = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
results = append(results, row)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
func dbExec(db *gorm.DB, sql string, args ...any) map[string]any {
|
|
||||||
if db == nil {
|
|
||||||
return map[string]any{
|
|
||||||
"rowsAffected": int64(0),
|
|
||||||
"error": "database is nil",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := db.Exec(sql, args...)
|
|
||||||
return map[string]any{
|
|
||||||
"rowsAffected": result.RowsAffected,
|
|
||||||
"error": result.Error,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
package jsrt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JSFetchRequest struct {
|
|
||||||
Method string `json:"method"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
Headers map[string]string `json:"headers"`
|
|
||||||
Body any `json:"body"`
|
|
||||||
Timeout int `json:"timeout"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type JSFetchResponse struct {
|
|
||||||
Status int `json:"status"`
|
|
||||||
Headers map[string]string `json:"headers"`
|
|
||||||
Body string `json:"body"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) fetch(url string, options ...any) *JSFetchResponse {
|
|
||||||
req := &JSFetchRequest{
|
|
||||||
Method: "GET",
|
|
||||||
URL: url,
|
|
||||||
Headers: make(map[string]string),
|
|
||||||
Timeout: int(jsConfig.FetchTimeout.Seconds()),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析选项
|
|
||||||
if len(options) > 0 && options[0] != nil {
|
|
||||||
if optMap, ok := options[0].(map[string]any); ok {
|
|
||||||
if method, exists := optMap["method"]; exists {
|
|
||||||
if methodStr, ok := method.(string); ok {
|
|
||||||
req.Method = strings.ToUpper(methodStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if headers, exists := optMap["headers"]; exists {
|
|
||||||
if headersMap, ok := headers.(map[string]any); ok {
|
|
||||||
for k, v := range headersMap {
|
|
||||||
if vStr, ok := v.(string); ok {
|
|
||||||
req.Headers[k] = vStr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if body, exists := optMap["body"]; exists {
|
|
||||||
req.Body = body
|
|
||||||
}
|
|
||||||
|
|
||||||
if timeout, exists := optMap["timeout"]; exists {
|
|
||||||
if timeoutNum, ok := timeout.(float64); ok {
|
|
||||||
req.Timeout = int(timeoutNum)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建HTTP请求
|
|
||||||
var bodyReader io.Reader
|
|
||||||
switch body := req.Body.(type) {
|
|
||||||
case string:
|
|
||||||
bodyReader = strings.NewReader(body)
|
|
||||||
case []byte:
|
|
||||||
bodyReader = bytes.NewReader(body)
|
|
||||||
case nil:
|
|
||||||
bodyReader = nil
|
|
||||||
default:
|
|
||||||
bodyBytes, err := json.Marshal(body)
|
|
||||||
if err != nil {
|
|
||||||
return &JSFetchResponse{
|
|
||||||
Error: fmt.Sprintf("Failed to marshal body: %v", err),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bodyReader = bytes.NewReader(bodyBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpReq, err := http.NewRequest(req.Method, req.URL, bodyReader)
|
|
||||||
if err != nil {
|
|
||||||
return &JSFetchResponse{
|
|
||||||
Error: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置请求头
|
|
||||||
for k, v := range req.Headers {
|
|
||||||
httpReq.Header.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置默认User-Agent
|
|
||||||
if httpReq.Header.Get("User-Agent") == "" {
|
|
||||||
httpReq.Header.Set("User-Agent", "JS-Runtime-Fetch/1.0")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建带超时的上下文
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
httpReq = httpReq.WithContext(ctx)
|
|
||||||
|
|
||||||
// 执行请求
|
|
||||||
resp, err := p.httpClient.Do(httpReq)
|
|
||||||
if err != nil {
|
|
||||||
return &JSFetchResponse{}
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// 读取响应体
|
|
||||||
bodyBytes, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return &JSFetchResponse{
|
|
||||||
Status: resp.StatusCode,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建响应头
|
|
||||||
headers := make(map[string]string)
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
if len(v) > 0 {
|
|
||||||
headers[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &JSFetchResponse{
|
|
||||||
Status: resp.StatusCode,
|
|
||||||
Headers: headers,
|
|
||||||
Body: string(bodyBytes),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,570 +0,0 @@
|
|||||||
package jsrt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/dop251/goja"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 池化
|
|
||||||
type JSRuntimePool struct {
|
|
||||||
pool chan *goja.Runtime
|
|
||||||
maxSize int
|
|
||||||
createFunc func() *goja.Runtime
|
|
||||||
scripts string
|
|
||||||
mu sync.RWMutex
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
jsRuntimePool *JSRuntimePool
|
|
||||||
jsPoolOnce sync.Once
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewJSRuntimePool(maxSize int) *JSRuntimePool {
|
|
||||||
// 创建HTTP客户端
|
|
||||||
httpClient := &http.Client{
|
|
||||||
Timeout: jsConfig.FetchTimeout,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
InsecureSkipVerify: false,
|
|
||||||
},
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := &JSRuntimePool{
|
|
||||||
pool: make(chan *goja.Runtime, maxSize),
|
|
||||||
maxSize: maxSize,
|
|
||||||
scripts: "",
|
|
||||||
httpClient: httpClient,
|
|
||||||
}
|
|
||||||
|
|
||||||
pool.createFunc = func() *goja.Runtime {
|
|
||||||
vm := goja.New()
|
|
||||||
pool.setupGlobals(vm)
|
|
||||||
pool.loadScripts(vm)
|
|
||||||
return vm
|
|
||||||
}
|
|
||||||
|
|
||||||
// 预创建VM
|
|
||||||
preCreate := min(maxSize/2, 4)
|
|
||||||
for range preCreate {
|
|
||||||
select {
|
|
||||||
case pool.pool <- pool.createFunc():
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) Get() *goja.Runtime {
|
|
||||||
select {
|
|
||||||
case vm := <-p.pool:
|
|
||||||
return vm
|
|
||||||
default:
|
|
||||||
return p.createFunc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) Put(vm *goja.Runtime) {
|
|
||||||
if vm == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case p.pool <- vm:
|
|
||||||
default:
|
|
||||||
// 池满,丢弃VM让GC回收
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) setupGlobals(vm *goja.Runtime) {
|
|
||||||
// console
|
|
||||||
console := vm.NewObject()
|
|
||||||
console.Set("log", func(args ...any) {
|
|
||||||
var strs []string
|
|
||||||
for _, arg := range args {
|
|
||||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
|
||||||
}
|
|
||||||
common.SysLog("JS: " + strings.Join(strs, " "))
|
|
||||||
})
|
|
||||||
console.Set("error", func(args ...any) {
|
|
||||||
var strs []string
|
|
||||||
for _, arg := range args {
|
|
||||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
|
||||||
}
|
|
||||||
common.SysError("JS: " + strings.Join(strs, " "))
|
|
||||||
})
|
|
||||||
console.Set("warn", func(args ...any) {
|
|
||||||
var strs []string
|
|
||||||
for _, arg := range args {
|
|
||||||
strs = append(strs, fmt.Sprintf("%v", arg))
|
|
||||||
}
|
|
||||||
common.SysError("JS WARN: " + strings.Join(strs, " "))
|
|
||||||
})
|
|
||||||
vm.Set("console", console)
|
|
||||||
|
|
||||||
// JSON
|
|
||||||
jsonObj := vm.NewObject()
|
|
||||||
jsonObj.Set("parse", func(str string) any {
|
|
||||||
var result any
|
|
||||||
err := json.Unmarshal([]byte(str), &result)
|
|
||||||
if err != nil {
|
|
||||||
panic(vm.ToValue(err.Error()))
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
})
|
|
||||||
jsonObj.Set("stringify", func(obj any) string {
|
|
||||||
data, err := json.Marshal(obj)
|
|
||||||
if err != nil {
|
|
||||||
panic(vm.ToValue(err.Error()))
|
|
||||||
}
|
|
||||||
return string(data)
|
|
||||||
})
|
|
||||||
vm.Set("JSON", jsonObj)
|
|
||||||
|
|
||||||
// fetch 实现
|
|
||||||
vm.Set("fetch", func(url string, options ...any) *JSFetchResponse {
|
|
||||||
return p.fetch(url, options...)
|
|
||||||
})
|
|
||||||
|
|
||||||
// 数据库
|
|
||||||
setDB(vm, model.DB, "db")
|
|
||||||
setDB(vm, model.LOG_DB, "logdb")
|
|
||||||
|
|
||||||
// 定时器
|
|
||||||
vm.Set("setTimeout", func(fn func(), delay int) {
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
|
||||||
fn()
|
|
||||||
}()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) loadScripts(vm *goja.Runtime) {
|
|
||||||
p.mu.RLock()
|
|
||||||
defer p.mu.RUnlock()
|
|
||||||
|
|
||||||
// 如果已经缓存了合并的脚本,直接使用
|
|
||||||
if p.scripts != "" {
|
|
||||||
if _, err := vm.RunString(p.scripts); err != nil {
|
|
||||||
common.SysError("Failed to load cached scripts: " + err.Error())
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 首次加载时,读取 scripts/ 文件夹中的所有脚本
|
|
||||||
p.mu.RUnlock()
|
|
||||||
p.mu.Lock()
|
|
||||||
defer func() {
|
|
||||||
p.mu.Unlock()
|
|
||||||
p.mu.RLock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if p.scripts != "" {
|
|
||||||
if _, err := vm.RunString(p.scripts); err != nil {
|
|
||||||
common.SysError("Failed to load cached scripts: " + err.Error())
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 读取所有脚本文件
|
|
||||||
var combinedScript strings.Builder
|
|
||||||
scriptDir := jsConfig.ScriptDir
|
|
||||||
|
|
||||||
// 检查目录是否存在
|
|
||||||
if _, err := os.Stat(scriptDir); os.IsNotExist(err) {
|
|
||||||
common.SysLog("Scripts directory does not exist: " + scriptDir)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 读取目录中的所有 .js 文件
|
|
||||||
files, err := filepath.Glob(filepath.Join(scriptDir, "*.js"))
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("Failed to read scripts directory: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(files) == 0 {
|
|
||||||
common.SysLog("No JavaScript files found in: " + scriptDir)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 按文件名排序以确保加载顺序一致
|
|
||||||
for _, file := range files {
|
|
||||||
content, err := os.ReadFile(file)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("Failed to read script file " + file + ": " + err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加文件注释和内容
|
|
||||||
combinedScript.WriteString("// File: " + filepath.Base(file) + "\n")
|
|
||||||
combinedScript.WriteString(string(content))
|
|
||||||
combinedScript.WriteString("\n\n")
|
|
||||||
|
|
||||||
common.SysLog("Loaded script: " + filepath.Base(file))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 缓存合并后的脚本
|
|
||||||
p.scripts = combinedScript.String()
|
|
||||||
|
|
||||||
// 执行脚本
|
|
||||||
if p.scripts != "" {
|
|
||||||
if _, err := vm.RunString(p.scripts); err != nil {
|
|
||||||
common.SysError("Failed to load combined scripts: " + err.Error())
|
|
||||||
} else {
|
|
||||||
common.SysLog("Successfully loaded and combined all JavaScript files from: " + scriptDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) ReloadScripts() {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
|
|
||||||
// 清空缓存的脚本
|
|
||||||
p.scripts = ""
|
|
||||||
|
|
||||||
// 清空VM池,强制重新创建
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-p.pool:
|
|
||||||
default:
|
|
||||||
goto done
|
|
||||||
}
|
|
||||||
}
|
|
||||||
done:
|
|
||||||
common.SysLog("JavaScript scripts reloaded")
|
|
||||||
}
|
|
||||||
|
|
||||||
func initJSRuntimePool() *JSRuntimePool {
|
|
||||||
jsPoolOnce.Do(func() {
|
|
||||||
jsRuntimePool = NewJSRuntimePool(jsConfig.MaxVMCount)
|
|
||||||
common.SysLog("JavaScript runtime pool initialized successfully")
|
|
||||||
})
|
|
||||||
return jsRuntimePool
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateGinContext(c *gin.Context) error {
|
|
||||||
if c == nil {
|
|
||||||
return fmt.Errorf("gin context is nil")
|
|
||||||
}
|
|
||||||
if c.Request == nil {
|
|
||||||
return fmt.Errorf("gin context request is nil")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) executeWithTimeout(_ *goja.Runtime, fn func() (goja.Value, error)) (goja.Value, error) {
|
|
||||||
type result struct {
|
|
||||||
value goja.Value
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
resultChan := make(chan result, 1)
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
resultChan <- result{err: fmt.Errorf("JS panic: %v", r)}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
value, err := fn()
|
|
||||||
resultChan <- result{value: value, err: err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case res := <-resultChan:
|
|
||||||
return res.value, res.err
|
|
||||||
case <-time.After(jsConfig.ScriptTimeout):
|
|
||||||
return nil, fmt.Errorf("script execution timeout after %v", jsConfig.ScriptTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) PreProcessRequest(c *gin.Context) error {
|
|
||||||
if err := validateGinContext(c); err != nil {
|
|
||||||
common.SysError("JS PreProcess Validation Error: " + err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
vm := p.Get()
|
|
||||||
defer p.Put(vm)
|
|
||||||
|
|
||||||
preProcessFunc := vm.Get("preProcessRequest")
|
|
||||||
if preProcessFunc == nil || goja.IsUndefined(preProcessFunc) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsReq, err := common.StructToMap(createJSReq(c))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create JS context: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
|
|
||||||
fn, ok := goja.AssertFunction(preProcessFunc)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("preProcessRequest is not a function")
|
|
||||||
}
|
|
||||||
return fn(goja.Undefined(), vm.ToValue(jsReq))
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("JS PreProcess Error: " + err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理返回结果
|
|
||||||
if result != nil && !goja.IsUndefined(result) {
|
|
||||||
resultObj := result.Export()
|
|
||||||
if resultMap, ok := resultObj.(map[string]any); ok {
|
|
||||||
// 是否修改请求
|
|
||||||
if newBody, exists := resultMap["body"]; exists {
|
|
||||||
switch v := newBody.(type) {
|
|
||||||
case string:
|
|
||||||
c.Request.Body = io.NopCloser(strings.NewReader(v))
|
|
||||||
c.Request.ContentLength = int64(len(v))
|
|
||||||
case []byte:
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(v))
|
|
||||||
c.Request.ContentLength = int64(len(v))
|
|
||||||
case map[string]any:
|
|
||||||
bodyBytes, err := json.Marshal(v)
|
|
||||||
if err == nil {
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
c.Request.ContentLength = int64(len(bodyBytes))
|
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
|
||||||
} else {
|
|
||||||
common.SysError("JS PreProcess JSON Marshal Error: " + err.Error())
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
common.SysError("JS PreProcess Unsupported Body Type: " + fmt.Sprintf("%T", newBody))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 是否修改 headers
|
|
||||||
if newHeaders, exists := resultMap["headers"]; exists {
|
|
||||||
if headersMap, ok := newHeaders.(map[string]any); ok {
|
|
||||||
for key, value := range headersMap {
|
|
||||||
if valueStr, ok := value.(string); ok {
|
|
||||||
c.Request.Header.Set(key, valueStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 是否阻止请求
|
|
||||||
if block, exists := resultMap["block"]; exists {
|
|
||||||
if blockBool, ok := block.(bool); ok && blockBool {
|
|
||||||
status := http.StatusForbidden
|
|
||||||
if statusCode, exists := resultMap["statusCode"]; exists {
|
|
||||||
if statusInt, ok := statusCode.(float64); ok {
|
|
||||||
status = int(statusInt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
message := "Request blocked by pre-process script"
|
|
||||||
if msg, exists := resultMap["message"]; exists {
|
|
||||||
if msgStr, ok := msg.(string); ok {
|
|
||||||
message = msgStr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(status, gin.H{"error": message})
|
|
||||||
c.Abort()
|
|
||||||
return fmt.Errorf("request blocked")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) PostProcessResponse(c *gin.Context, statusCode int, body []byte) (int, []byte, error) {
|
|
||||||
if err := validateGinContext(c); err != nil {
|
|
||||||
common.SysError("JS PostProcess Validation Error: " + err.Error())
|
|
||||||
return statusCode, body, err
|
|
||||||
}
|
|
||||||
|
|
||||||
vm := p.Get()
|
|
||||||
defer p.Put(vm)
|
|
||||||
|
|
||||||
postProcessFunc := vm.Get("postProcessResponse")
|
|
||||||
if postProcessFunc == nil || goja.IsUndefined(postProcessFunc) {
|
|
||||||
return statusCode, body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsReq, err := common.StructToMap(createJSReq(c))
|
|
||||||
if err != nil {
|
|
||||||
return statusCode, body, fmt.Errorf("failed to create JS context: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
jsResp := &JSResponse{
|
|
||||||
StatusCode: statusCode,
|
|
||||||
Headers: make(map[string]string),
|
|
||||||
Body: string(body),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取响应头
|
|
||||||
if c.Writer != nil {
|
|
||||||
for key, values := range c.Writer.Header() {
|
|
||||||
if len(values) > 0 {
|
|
||||||
jsResp.Headers[key] = values[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jsResponse, err := common.StructToMap(jsResp)
|
|
||||||
if err != nil {
|
|
||||||
return statusCode, body, fmt.Errorf("failed to create JS response context: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := p.executeWithTimeout(vm, func() (goja.Value, error) {
|
|
||||||
fn, ok := goja.AssertFunction(postProcessFunc)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("postProcessResponse is not a function")
|
|
||||||
}
|
|
||||||
return fn(goja.Undefined(), vm.ToValue(jsReq), vm.ToValue(jsResponse))
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("JS PostProcess Error: " + err.Error())
|
|
||||||
return statusCode, body, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理返回
|
|
||||||
if result != nil && !goja.IsUndefined(result) {
|
|
||||||
resultObj := result.Export()
|
|
||||||
if resultMap, ok := resultObj.(map[string]any); ok {
|
|
||||||
if newStatusCode, exists := resultMap["statusCode"]; exists {
|
|
||||||
if statusInt, ok := newStatusCode.(float64); ok {
|
|
||||||
statusCode = int(statusInt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if newBody, exists := resultMap["body"]; exists {
|
|
||||||
if bodyStr, ok := newBody.(string); ok {
|
|
||||||
body = []byte(bodyStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if newHeaders, exists := resultMap["headers"]; exists {
|
|
||||||
if headersMap, ok := newHeaders.(map[string]any); ok {
|
|
||||||
for key, value := range headersMap {
|
|
||||||
if valueStr, ok := value.(string); ok {
|
|
||||||
c.Header(key, valueStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return statusCode, body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSRuntimePool) hasPostProcessFunction() bool {
|
|
||||||
vm := p.Get()
|
|
||||||
defer p.Put(vm)
|
|
||||||
postProcessFunc := vm.Get("postProcessResponse")
|
|
||||||
return postProcessFunc != nil && !goja.IsUndefined(postProcessFunc)
|
|
||||||
}
|
|
||||||
|
|
||||||
func JSRuntimeMiddleware() *gin.HandlerFunc {
|
|
||||||
loadCfg()
|
|
||||||
if !jsConfig.Enabled {
|
|
||||||
common.SysLog("JavaScript Runtime is disabled")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := initJSRuntimePool()
|
|
||||||
var fn gin.HandlerFunc
|
|
||||||
fn = func(c *gin.Context) {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// 预处理
|
|
||||||
if err := pool.PreProcessRequest(c); err != nil {
|
|
||||||
common.SysError("JS Runtime PreProcess Error: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
duration := time.Since(start)
|
|
||||||
if duration > time.Millisecond*100 {
|
|
||||||
common.SysLog(fmt.Sprintf("JS Runtime PreProcess took %v", duration))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 后处理
|
|
||||||
if pool.hasPostProcessFunction() {
|
|
||||||
writer := newResponseWriter(c.Writer)
|
|
||||||
c.Writer = writer
|
|
||||||
|
|
||||||
c.Next()
|
|
||||||
|
|
||||||
// 后处理响应
|
|
||||||
if writer.body.Len() > 0 {
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
statusCode, body, err := pool.PostProcessResponse(c, writer.statusCode, writer.body.Bytes())
|
|
||||||
if err == nil {
|
|
||||||
c.Writer = writer.ResponseWriter
|
|
||||||
|
|
||||||
for k, v := range writer.headerMap {
|
|
||||||
for _, value := range v {
|
|
||||||
c.Writer.Header().Add(k, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(statusCode)
|
|
||||||
|
|
||||||
if len(body) >= 0 {
|
|
||||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body)))
|
|
||||||
c.Writer.Write(body)
|
|
||||||
} else {
|
|
||||||
c.Writer.Header().Del("Content-Length")
|
|
||||||
c.Writer.Write(body)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 出错时回复原响应
|
|
||||||
c.Writer = writer.ResponseWriter
|
|
||||||
c.Status(writer.statusCode)
|
|
||||||
|
|
||||||
common.SysError(fmt.Sprintf("JS Runtime PostProcess Error: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
duration := time.Since(start)
|
|
||||||
if duration > time.Millisecond*100 {
|
|
||||||
common.SysLog(fmt.Sprintf("JS Runtime PostProcess took %v", duration))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 没有响应体时,恢复原始writer
|
|
||||||
c.Writer = writer.ResponseWriter
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.Next()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &fn
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReloadJSScripts() {
|
|
||||||
if jsRuntimePool != nil {
|
|
||||||
jsRuntimePool.ReloadScripts()
|
|
||||||
common.SysLog("JavaScript scripts reloaded")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
package jsrt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"maps"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 请求
|
|
||||||
type JSReq struct {
|
|
||||||
Method string `json:"method"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
Headers map[string]string `json:"headers"`
|
|
||||||
Body any `json:"body"`
|
|
||||||
UserAgent string `json:"userAgent"`
|
|
||||||
RemoteIP string `json:"remoteIP"`
|
|
||||||
Extra map[string]any `json:"extra"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type JSResponse struct {
|
|
||||||
StatusCode int `json:"statusCode"`
|
|
||||||
Headers map[string]string `json:"headers"`
|
|
||||||
Body string `json:"body"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type responseWriter struct {
|
|
||||||
gin.ResponseWriter
|
|
||||||
body *bytes.Buffer
|
|
||||||
statusCode int
|
|
||||||
headerMap http.Header
|
|
||||||
written bool
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func createJSReq(c *gin.Context) *JSReq {
|
|
||||||
var bodyBytes []byte
|
|
||||||
if c.Request != nil && c.Request.Body != nil {
|
|
||||||
bodyBytes, _ = io.ReadAll(c.Request.Body)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// headers map
|
|
||||||
headers := make(map[string]string)
|
|
||||||
if c.Request != nil && c.Request.Header != nil {
|
|
||||||
for key, values := range c.Request.Header {
|
|
||||||
if len(values) > 0 {
|
|
||||||
headers[key] = values[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
method := ""
|
|
||||||
url := ""
|
|
||||||
userAgent := ""
|
|
||||||
remoteIP := ""
|
|
||||||
contentType := ""
|
|
||||||
|
|
||||||
if c.Request != nil {
|
|
||||||
method = c.Request.Method
|
|
||||||
if c.Request.URL != nil {
|
|
||||||
url = c.Request.URL.String()
|
|
||||||
}
|
|
||||||
userAgent = c.Request.UserAgent()
|
|
||||||
contentType = c.ContentType()
|
|
||||||
}
|
|
||||||
|
|
||||||
if c != nil {
|
|
||||||
remoteIP = c.ClientIP()
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedBody := parseBodyByType(bodyBytes, contentType)
|
|
||||||
|
|
||||||
return &JSReq{
|
|
||||||
Method: method,
|
|
||||||
URL: url,
|
|
||||||
Headers: headers,
|
|
||||||
Body: parsedBody,
|
|
||||||
UserAgent: userAgent,
|
|
||||||
RemoteIP: remoteIP,
|
|
||||||
Extra: make(map[string]any),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newResponseWriter(w gin.ResponseWriter) *responseWriter {
|
|
||||||
return &responseWriter{
|
|
||||||
ResponseWriter: w,
|
|
||||||
body: &bytes.Buffer{},
|
|
||||||
statusCode: 200,
|
|
||||||
headerMap: make(http.Header),
|
|
||||||
written: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) Write(data []byte) (int, error) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if !w.written {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}
|
|
||||||
return w.body.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) WriteString(s string) (int, error) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if !w.written {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}
|
|
||||||
return w.body.WriteString(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) WriteHeader(statusCode int) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if w.written {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.statusCode = statusCode
|
|
||||||
w.written = true
|
|
||||||
|
|
||||||
maps.Copy(w.headerMap, w.ResponseWriter.Header())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *responseWriter) Header() http.Header {
|
|
||||||
w.mu.RLock()
|
|
||||||
defer w.mu.RUnlock()
|
|
||||||
|
|
||||||
if w.headerMap == nil {
|
|
||||||
w.headerMap = make(http.Header)
|
|
||||||
}
|
|
||||||
return w.headerMap
|
|
||||||
}
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
package jsrt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/url"
|
|
||||||
"one-api/common"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/dop251/goja"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setDB(vm *goja.Runtime, db *gorm.DB, name string) {
|
|
||||||
if db == nil {
|
|
||||||
common.SysError("JS DB is nil")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
obj := vm.NewObject()
|
|
||||||
obj.Set("query", func(sql string, params ...any) []map[string]any {
|
|
||||||
return dbQuery(db, sql, params...)
|
|
||||||
})
|
|
||||||
obj.Set("exec", func(sql string, params ...any) map[string]any {
|
|
||||||
return dbExec(db, sql, params...)
|
|
||||||
})
|
|
||||||
if err := vm.Set(name, obj); err != nil {
|
|
||||||
common.SysError("Failed to set JS DB: " + err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseBodyByType(bodyBytes []byte, contentType string) any {
|
|
||||||
if len(bodyBytes) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyStr := string(bodyBytes)
|
|
||||||
contentLower := strings.ToLower(contentType)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case strings.Contains(contentLower, "application/json"):
|
|
||||||
var jsonObj any
|
|
||||||
if err := json.Unmarshal(bodyBytes, &jsonObj); err == nil {
|
|
||||||
return jsonObj
|
|
||||||
}
|
|
||||||
return bodyStr
|
|
||||||
|
|
||||||
case strings.Contains(contentLower, "application/x-www-form-urlencoded"):
|
|
||||||
if values, err := url.ParseQuery(bodyStr); err == nil {
|
|
||||||
result := make(map[string]string, len(values))
|
|
||||||
for k, v := range values {
|
|
||||||
if len(v) > 0 {
|
|
||||||
result[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return bodyStr
|
|
||||||
|
|
||||||
case strings.Contains(contentLower, "multipart/form-data"):
|
|
||||||
return bodyBytes
|
|
||||||
|
|
||||||
case strings.Contains(contentLower, "text/"):
|
|
||||||
return bodyStr
|
|
||||||
|
|
||||||
default:
|
|
||||||
// 尝试JSON解析
|
|
||||||
var jsonObj any
|
|
||||||
if json.Unmarshal(bodyBytes, &jsonObj) == nil {
|
|
||||||
return jsonObj
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尝试form解析
|
|
||||||
if values, err := url.ParseQuery(bodyStr); err == nil && len(values) > 0 {
|
|
||||||
result := make(map[string]string, len(values))
|
|
||||||
for k, v := range values {
|
|
||||||
if len(v) > 0 {
|
|
||||||
result[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
return bodyStr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,12 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -36,8 +40,100 @@ type Channel struct {
|
|||||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||||
OtherInfo string `json:"other_info"`
|
OtherInfo string `json:"other_info"`
|
||||||
Tag *string `json:"tag" gorm:"index"`
|
Tag *string `json:"tag" gorm:"index"`
|
||||||
Setting *string `json:"setting" gorm:"type:text"`
|
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||||
|
// add after v0.8.5
|
||||||
|
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelInfo struct {
|
||||||
|
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||||
|
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||||
|
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||||
|
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements driver.Valuer interface
|
||||||
|
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||||
|
return common.Marshal(&c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements sql.Scanner interface
|
||||||
|
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||||
|
bytesValue, _ := value.([]byte)
|
||||||
|
return common.Unmarshal(bytesValue, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) getKeys() []string {
|
||||||
|
if channel.Key == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
// use \n to split keys
|
||||||
|
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetNextEnabledKey() (string, error) {
|
||||||
|
// If not in multi-key mode, return the original key string directly.
|
||||||
|
if !channel.ChannelInfo.IsMultiKey {
|
||||||
|
return channel.Key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain all keys (split by \n)
|
||||||
|
keys := channel.getKeys()
|
||||||
|
if len(keys) == 0 {
|
||||||
|
// No keys available, return error, should disable the channel
|
||||||
|
return "", fmt.Errorf("no valid keys in channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||||
|
// helper to get key status, default to enabled when missing
|
||||||
|
getStatus := func(idx int) int {
|
||||||
|
if statusList == nil {
|
||||||
|
return common.ChannelStatusEnabled
|
||||||
|
}
|
||||||
|
if status, ok := statusList[idx]; ok {
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
return common.ChannelStatusEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect indexes of enabled keys
|
||||||
|
enabledIdx := make([]int, 0, len(keys))
|
||||||
|
for i := range keys {
|
||||||
|
if getStatus(i) == common.ChannelStatusEnabled {
|
||||||
|
enabledIdx = append(enabledIdx, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If no specific status list or none enabled, fall back to first key
|
||||||
|
if len(enabledIdx) == 0 {
|
||||||
|
return keys[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch channel.ChannelInfo.MultiKeyMode {
|
||||||
|
case constant.MultiKeyModeRandom:
|
||||||
|
// Randomly pick one enabled key
|
||||||
|
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
|
||||||
|
case constant.MultiKeyModePolling:
|
||||||
|
// Start from the saved polling index and look for the next enabled key
|
||||||
|
start := channel.ChannelInfo.MultiKeyPollingIndex
|
||||||
|
if start < 0 || start >= len(keys) {
|
||||||
|
start = 0
|
||||||
|
}
|
||||||
|
for i := 0; i < len(keys); i++ {
|
||||||
|
idx := (start + i) % len(keys)
|
||||||
|
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||||
|
// update polling index for next call (point to the next position)
|
||||||
|
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||||
|
return keys[idx], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback – should not happen, but return first enabled key
|
||||||
|
return keys[enabledIdx[0]], nil
|
||||||
|
default:
|
||||||
|
// Unknown mode, default to first enabled key (or original key string)
|
||||||
|
return keys[enabledIdx[0]], nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetModels() []string {
|
func (channel *Channel) GetModels() []string {
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func formatUserLogs(logs []*Log) {
|
|||||||
for i := range logs {
|
for i := range logs {
|
||||||
logs[i].ChannelName = ""
|
logs[i].ChannelName = ""
|
||||||
var otherMap map[string]interface{}
|
var otherMap map[string]interface{}
|
||||||
otherMap = common.StrToMap(logs[i].Other)
|
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||||
if otherMap != nil {
|
if otherMap != nil {
|
||||||
// delete admin
|
// delete admin
|
||||||
delete(otherMap, "admin_info")
|
delete(otherMap, "admin_info")
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func initCol() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// log sql type and database type
|
// log sql type and database type
|
||||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||||
}
|
}
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
@@ -225,12 +225,6 @@ func InitLogDB() (err error) {
|
|||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
//if common.UsingMySQL {
|
|
||||||
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
|
||||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
|
||||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
|
||||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
|
||||||
//}
|
|
||||||
common.SysLog("database migration started")
|
common.SysLog("database migration started")
|
||||||
err = migrateLOGDB()
|
err = migrateLOGDB()
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -36,7 +35,7 @@ func (user *UserBase) WriteContext(c *gin.Context) {
|
|||||||
func (user *UserBase) GetSetting() dto.UserSetting {
|
func (user *UserBase) GetSetting() dto.UserSetting {
|
||||||
setting := dto.UserSetting{}
|
setting := dto.UserSetting{}
|
||||||
if user.Setting != "" {
|
if user.Setting != "" {
|
||||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
err := common.Unmarshal([]byte(user.Setting), &setting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package relay
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -12,7 +11,10 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||||
@@ -54,13 +56,13 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|||||||
return audioRequest, nil
|
return audioRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
@@ -73,7 +75,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
|
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
@@ -88,23 +90,23 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
|
|
||||||
@@ -112,18 +114,18 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
|
|
||||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -21,7 +22,7 @@ type Adaptor interface {
|
|||||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||||
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
|
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
|
||||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
|
||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -99,7 +100,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
err, usage = aliImageHandler(c, resp, info)
|
err, usage = aliImageHandler(c, resp, info)
|
||||||
@@ -109,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
err, usage = RerankHandler(c, resp, info)
|
err, usage = RerankHandler(c, resp, info)
|
||||||
default:
|
default:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,15 +4,17 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||||
@@ -124,49 +126,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
|||||||
return &imageResponse
|
return &imageResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseFormat := c.GetString("response_format")
|
responseFormat := c.GetString("response_format")
|
||||||
|
|
||||||
var aliTaskResponse AliResponse
|
var aliTaskResponse AliResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if aliTaskResponse.Message != "" {
|
if aliTaskResponse.Message != "" {
|
||||||
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
||||||
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
|
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.WithOpenAIError(types.OpenAIError{
|
||||||
Error: dto.OpenAIError{
|
Message: aliResponse.Output.Message,
|
||||||
Message: aliResponse.Output.Message,
|
Type: "ali_error",
|
||||||
Type: "ali_error",
|
Param: "",
|
||||||
Param: "",
|
Code: aliResponse.Output.Code,
|
||||||
Code: aliResponse.Output.Code,
|
}, resp.StatusCode), nil
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
c.Writer.Write(jsonResponse)
|
||||||
return nil, nil
|
return nil, &dto.Usage{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -31,29 +31,26 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
var aliResponse AliRerankResponse
|
var aliResponse AliRerankResponse
|
||||||
err = json.Unmarshal(responseBody, &aliResponse)
|
err = json.Unmarshal(responseBody, &aliResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if aliResponse.Code != "" {
|
if aliResponse.Code != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.WithOpenAIError(types.OpenAIError{
|
||||||
Error: dto.OpenAIError{
|
Message: aliResponse.Message,
|
||||||
Message: aliResponse.Message,
|
Type: aliResponse.Code,
|
||||||
Type: aliResponse.Code,
|
Param: aliResponse.RequestId,
|
||||||
Param: aliResponse.RequestId,
|
Code: aliResponse.Code,
|
||||||
Code: aliResponse.Code,
|
}, resp.StatusCode), nil
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
@@ -68,14 +65,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(rerankResponse)
|
jsonResponse, err := json.Marshal(rerankResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
c.Writer.Write(jsonResponse)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,9 +8,10 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,11 +39,11 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var fullTextResponse dto.OpenAIEmbeddingResponse
|
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
@@ -53,11 +54,11 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
|
|||||||
}
|
}
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
c.Writer.Write(jsonResponse)
|
||||||
return nil, &fullTextResponse.Usage
|
return nil, &fullTextResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +120,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var usage dto.Usage
|
var usage dto.Usage
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
@@ -174,32 +175,29 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
|||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var aliResponse AliResponse
|
var aliResponse AliResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &aliResponse)
|
err = json.Unmarshal(responseBody, &aliResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
if aliResponse.Code != "" {
|
if aliResponse.Code != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.WithOpenAIError(types.OpenAIError{
|
||||||
Error: dto.OpenAIError{
|
Message: aliResponse.Message,
|
||||||
Message: aliResponse.Message,
|
Type: "ali_error",
|
||||||
Type: aliResponse.Code,
|
Param: aliResponse.RequestId,
|
||||||
Param: aliResponse.RequestId,
|
Code: aliResponse.Code,
|
||||||
Code: aliResponse.Code,
|
}, resp.StatusCode), nil
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -84,7 +85,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -3,19 +3,22 @@ package aws
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||||
@@ -65,24 +68,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
|||||||
return modelPrefix + "." + awsModelId
|
return modelPrefix + "." + awsModelId
|
||||||
}
|
}
|
||||||
|
|
||||||
func awsModelID(requestModel string) (string, error) {
|
func awsModelID(requestModel string) string {
|
||||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||||
return awsModelID, nil
|
return awsModelID
|
||||||
}
|
}
|
||||||
|
|
||||||
return requestModel, nil
|
return requestModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||||
awsCli, err := newAwsClient(c, info)
|
awsCli, err := newAwsClient(c, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
awsModelId := awsModelID(c.GetString("request_model"))
|
||||||
if err != nil {
|
|
||||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
@@ -98,42 +98,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|||||||
|
|
||||||
claudeReq_, ok := c.Get("converted_request")
|
claudeReq_, ok := c.Get("converted_request")
|
||||||
if !ok {
|
if !ok {
|
||||||
return wrapErr(errors.New("request not found")), nil
|
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||||
}
|
}
|
||||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||||
awsClaudeReq := copyRequest(claudeReq)
|
awsClaudeReq := copyRequest(claudeReq)
|
||||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeInfo := &claude.ClaudeResponseInfo{
|
claudeInfo := &claude.ClaudeResponseInfo{
|
||||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ResponseId: helper.GetResponseID(c),
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: info.UpstreamModelName,
|
Model: info.UpstreamModelName,
|
||||||
ResponseText: strings.Builder{},
|
ResponseText: strings.Builder{},
|
||||||
Usage: &dto.Usage{},
|
Usage: &dto.Usage{},
|
||||||
}
|
}
|
||||||
|
|
||||||
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||||
|
if handlerErr != nil {
|
||||||
|
return handlerErr, nil
|
||||||
|
}
|
||||||
return nil, claudeInfo.Usage
|
return nil, claudeInfo.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||||
awsCli, err := newAwsClient(c, info)
|
awsCli, err := newAwsClient(c, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
awsModelId := awsModelID(c.GetString("request_model"))
|
||||||
if err != nil {
|
|
||||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
@@ -149,25 +149,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
|
|
||||||
claudeReq_, ok := c.Get("converted_request")
|
claudeReq_, ok := c.Get("converted_request")
|
||||||
if !ok {
|
if !ok {
|
||||||
return wrapErr(errors.New("request not found")), nil
|
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||||
}
|
}
|
||||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||||
|
|
||||||
awsClaudeReq := copyRequest(claudeReq)
|
awsClaudeReq := copyRequest(claudeReq)
|
||||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
|
||||||
}
|
}
|
||||||
stream := awsResp.GetStream()
|
stream := awsResp.GetStream()
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
claudeInfo := &claude.ClaudeResponseInfo{
|
claudeInfo := &claude.ClaudeResponseInfo{
|
||||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ResponseId: helper.GetResponseID(c),
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: info.UpstreamModelName,
|
Model: info.UpstreamModelName,
|
||||||
ResponseText: strings.Builder{},
|
ResponseText: strings.Builder{},
|
||||||
@@ -176,18 +176,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
|
|
||||||
for event := range stream.Events() {
|
for event := range stream.Events() {
|
||||||
switch v := event.(type) {
|
switch v := event.(type) {
|
||||||
case *types.ResponseStreamMemberChunk:
|
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||||
info.SetFirstResponseTime()
|
info.SetFirstResponseTime()
|
||||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return respErr, nil
|
return respErr, nil
|
||||||
}
|
}
|
||||||
case *types.UnknownUnionMember:
|
case *bedrockruntimeTypes.UnknownUnionMember:
|
||||||
fmt.Println("unknown tag:", v.Tag)
|
fmt.Println("unknown tag:", v.Tag)
|
||||||
return wrapErr(errors.New("unknown response type")), nil
|
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||||
default:
|
default:
|
||||||
fmt.Println("union is nil or unknown type")
|
fmt.Println("union is nil or unknown type")
|
||||||
return wrapErr(errors.New("nil or unknown response type")), nil
|
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -140,15 +141,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = baiduStreamHandler(c, resp)
|
err, usage = baiduStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = baiduEmbeddingHandler(c, resp)
|
err, usage = baiduEmbeddingHandler(c, info, resp)
|
||||||
default:
|
default:
|
||||||
err, usage = baiduHandler(c, resp)
|
err, usage = baiduHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,21 +1,23 @@
|
|||||||
package baidu
|
package baidu
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||||
@@ -110,92 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
|||||||
return &openAIEmbeddingResponse
|
return &openAIEmbeddingResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var usage dto.Usage
|
usage := &dto.Usage{}
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
var baiduResponse BaiduChatStreamResponse
|
||||||
if atEOF && len(data) == 0 {
|
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||||
return 0, nil, nil
|
if err != nil {
|
||||||
}
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
dataChan := make(chan string)
|
|
||||||
stopChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // ignore blank line or wrong format
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data = data[6:]
|
|
||||||
dataChan <- data
|
|
||||||
}
|
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
helper.SetEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
select {
|
|
||||||
case data := <-dataChan:
|
|
||||||
var baiduResponse BaiduChatStreamResponse
|
|
||||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if baiduResponse.Usage.TotalTokens != 0 {
|
|
||||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
|
||||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
|
||||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
|
||||||
}
|
|
||||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
return true
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
if baiduResponse.Usage.TotalTokens != 0 {
|
||||||
|
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||||
|
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||||
|
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||||
|
}
|
||||||
|
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||||
|
err = helper.ObjectData(c, response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error sending stream response: " + err.Error())
|
||||||
|
}
|
||||||
|
return true
|
||||||
})
|
})
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
return nil, &usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var baiduResponse BaiduChatResponse
|
var baiduResponse BaiduChatResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
if baiduResponse.ErrorMsg != "" {
|
if baiduResponse.ErrorMsg != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: baiduResponse.ErrorMsg,
|
|
||||||
Type: "baidu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: baiduResponse.ErrorCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
@@ -203,32 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
|||||||
return nil, &fullTextResponse.Usage
|
return nil, &fullTextResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var baiduResponse BaiduEmbeddingResponse
|
var baiduResponse BaiduEmbeddingResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
if baiduResponse.ErrorMsg != "" {
|
if baiduResponse.ErrorMsg != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: baiduResponse.ErrorMsg,
|
|
||||||
Type: "baidu_error",
|
|
||||||
Param: "",
|
|
||||||
Code: baiduResponse.ErrorCode,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -92,11 +93,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -94,7 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -125,7 +126,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
|||||||
|
|
||||||
if textRequest.Reasoning != nil {
|
if textRequest.Reasoning != nil {
|
||||||
var reasoning openrouter.RequestReasoning
|
var reasoning openrouter.RequestReasoning
|
||||||
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
|
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -517,22 +518,15 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
|
||||||
var claudeResponse dto.ClaudeResponse
|
var claudeResponse dto.ClaudeResponse
|
||||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Code: "stream_response_error",
|
|
||||||
Type: claudeResponse.Error.Type,
|
|
||||||
Message: claudeResponse.Error.Message,
|
|
||||||
},
|
|
||||||
StatusCode: http.StatusInternalServerError,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||||
@@ -593,15 +587,15 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||||
claudeInfo := &ClaudeResponseInfo{
|
claudeInfo := &ClaudeResponseInfo{
|
||||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ResponseId: helper.GetResponseID(c),
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: info.UpstreamModelName,
|
Model: info.UpstreamModelName,
|
||||||
ResponseText: strings.Builder{},
|
ResponseText: strings.Builder{},
|
||||||
Usage: &dto.Usage{},
|
Usage: &dto.Usage{},
|
||||||
}
|
}
|
||||||
var err *dto.OpenAIErrorWithStatusCode
|
var err *types.NewAPIError
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -617,21 +611,14 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
return nil, claudeInfo.Usage
|
return nil, claudeInfo.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||||
var claudeResponse dto.ClaudeResponse
|
var claudeResponse dto.ClaudeResponse
|
||||||
err := common.UnmarshalJson(data, &claudeResponse)
|
err := common.Unmarshal(data, &claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: claudeResponse.Error.Message,
|
|
||||||
Type: claudeResponse.Error.Type,
|
|
||||||
Code: claudeResponse.Error.Type,
|
|
||||||
},
|
|
||||||
StatusCode: http.StatusInternalServerError,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||||
@@ -652,7 +639,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
openaiResponse.Usage = *claudeInfo.Usage
|
openaiResponse.Usage = *claudeInfo.Usage
|
||||||
responseData, err = json.Marshal(openaiResponse)
|
responseData, err = json.Marshal(openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
case relaycommon.RelayFormatClaude:
|
case relaycommon.RelayFormatClaude:
|
||||||
responseData = data
|
responseData = data
|
||||||
@@ -662,11 +649,11 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
claudeInfo := &ClaudeResponseInfo{
|
claudeInfo := &ClaudeResponseInfo{
|
||||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ResponseId: helper.GetResponseID(c),
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: info.UpstreamModelName,
|
Model: info.UpstreamModelName,
|
||||||
ResponseText: strings.Builder{},
|
ResponseText: strings.Builder{},
|
||||||
@@ -674,7 +661,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
|||||||
}
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println("responseBody: ", string(responseBody))
|
println("responseBody: ", string(responseBody))
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -94,20 +95,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
fallthrough
|
fallthrough
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeChatCompletions:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = cfStreamHandler(c, resp, info)
|
err, usage = cfStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = cfHandler(c, resp, info)
|
err, usage = cfHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
case constant.RelayModeAudioTranslation:
|
case constant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case constant.RelayModeAudioTranscription:
|
case constant.RelayModeAudioTranscription:
|
||||||
err, usage = cfSTTHandler(c, resp, info)
|
err, usage = cfSTTHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package cloudflare
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -11,8 +10,11 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
||||||
@@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
|
|
||||||
@@ -86,16 +88,16 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var response dto.TextResponse
|
var response dto.TextResponse
|
||||||
err = json.Unmarshal(responseBody, &response)
|
err = json.Unmarshal(responseBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
var responseText string
|
var responseText string
|
||||||
@@ -107,7 +109,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
response.Id = helper.GetResponseID(c)
|
response.Id = helper.GetResponseID(c)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
@@ -115,16 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
var cfResp CfAudioResponse
|
var cfResp CfAudioResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &cfResp)
|
err = json.Unmarshal(responseBody, &cfResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
audioResp := &dto.AudioResponse{
|
audioResp := &dto.AudioResponse{
|
||||||
@@ -133,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(audioResp)
|
jsonResponse, err := json.Marshal(audioResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -71,14 +72,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = cohereRerankHandler(c, resp, info)
|
err, usage = cohereRerankHandler(c, resp, info)
|
||||||
} else {
|
} else {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = cohereStreamHandler(c, resp, info)
|
err, usage = cohereStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
err, usage = cohereHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package cohere
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -11,8 +10,11 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||||
@@ -76,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseId := helper.GetResponseID(c)
|
responseId := helper.GetResponseID(c)
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
@@ -167,17 +169,17 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var cohereResp CohereResponseResult
|
var cohereResp CohereResponseResult
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||||
@@ -188,7 +190,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
|||||||
openaiResp.Id = cohereResp.ResponseId
|
openaiResp.Id = cohereResp.ResponseId
|
||||||
openaiResp.Created = createdTime
|
openaiResp.Created = createdTime
|
||||||
openaiResp.Object = "chat.completion"
|
openaiResp.Object = "chat.completion"
|
||||||
openaiResp.Model = modelName
|
openaiResp.Model = info.UpstreamModelName
|
||||||
openaiResp.Usage = usage
|
openaiResp.Usage = usage
|
||||||
|
|
||||||
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||||
@@ -201,7 +203,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(openaiResp)
|
jsonResponse, err := json.Marshal(openaiResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
@@ -209,16 +211,16 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
|||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var cohereResp CohereRerankResponseResult
|
var cohereResp CohereRerankResponseResult
|
||||||
err = json.Unmarshal(responseBody, &cohereResp)
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
usage := dto.Usage{}
|
usage := dto.Usage{}
|
||||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||||
@@ -237,7 +239,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(rerankResp)
|
jsonResponse, err := json.Marshal(rerankResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/common"
|
"one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -95,11 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DoResponse implements channel.Adaptor.
|
// DoResponse implements channel.Adaptor.
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = cozeChatStreamHandler(c, resp, info)
|
err, usage = cozeChatStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = cozeChatHandler(c, resp, info)
|
err, usage = cozeChatHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -43,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
|
|||||||
return cozeRequest
|
return cozeRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
// convert coze response to openai response
|
// convert coze response to openai response
|
||||||
@@ -55,10 +56,10 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
response.Model = info.UpstreamModelName
|
response.Model = info.UpstreamModelName
|
||||||
err = json.Unmarshal(responseBody, &cozeResponse)
|
err = json.Unmarshal(responseBody, &cozeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
if cozeResponse.Code != 0 {
|
if cozeResponse.Code != 0 {
|
||||||
return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
|
return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
// 从上下文获取 usage
|
// 从上下文获取 usage
|
||||||
var usage dto.Usage
|
var usage dto.Usage
|
||||||
@@ -85,7 +86,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
}
|
}
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
@@ -94,7 +95,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
@@ -135,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -81,11 +82,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -96,11 +97,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = difyStreamHandler(c, resp, info)
|
return difyStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = difyHandler(c, resp, info)
|
return difyHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -209,7 +210,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var responseText string
|
var responseText string
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
var nodeToken int
|
var nodeToken int
|
||||||
@@ -247,20 +248,20 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
usage.CompletionTokens += nodeToken
|
usage.CompletionTokens += nodeToken
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var difyResponse DifyChatCompletionResponse
|
var difyResponse DifyChatCompletionResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &difyResponse)
|
err = json.Unmarshal(responseBody, &difyResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: difyResponse.ConversationId,
|
Id: difyResponse.ConversationId,
|
||||||
@@ -279,10 +280,10 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
|
|||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
c.Writer.Write(jsonResponse)
|
||||||
return nil, &difyResponse.MetaData.Usage
|
return &difyResponse.MetaData.Usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/service"
|
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -168,30 +168,30 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.RelayMode == constant.RelayModeGemini {
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
return GeminiTextGenerationStreamHandler(c, resp, info)
|
return GeminiTextGenerationStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
return GeminiTextGenerationHandler(c, resp, info)
|
return GeminiTextGenerationHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||||
return GeminiImageHandler(c, resp, info)
|
return GeminiImageHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if the model is an embedding model
|
// check if the model is an embedding model
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
|
||||||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
|
||||||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
|
||||||
return GeminiEmbeddingHandler(c, resp, info)
|
return GeminiEmbeddingHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
return GeminiChatStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = GeminiChatHandler(c, resp, info)
|
return GeminiChatHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
|
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
|
||||||
@@ -205,23 +205,23 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
// }
|
// }
|
||||||
//}
|
//}
|
||||||
|
|
||||||
return
|
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, readErr := io.ReadAll(resp.Body)
|
responseBody, readErr := io.ReadAll(resp.Body)
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
var geminiResponse GeminiImageResponse
|
var geminiResponse GeminiImageResponse
|
||||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(geminiResponse.Predictions) == 0 {
|
if len(geminiResponse.Predictions) == 0 {
|
||||||
return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
|
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert to openai format response
|
// convert to openai format response
|
||||||
@@ -241,7 +241,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
|||||||
|
|
||||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
@@ -253,7 +253,7 @@ func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
|
|||||||
const imageTokens = 258
|
const imageTokens = 258
|
||||||
generatedImages := len(openAIResponse.Data)
|
generatedImages := len(openAIResponse.Data)
|
||||||
|
|
||||||
usage = &dto.Usage{
|
usage := &dto.Usage{
|
||||||
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
|
||||||
CompletionTokens: 0, // image generation does not calculate completion tokens
|
CompletionTokens: 0, // image generation does not calculate completion tokens
|
||||||
TotalTokens: imageTokens * generatedImages,
|
TotalTokens: imageTokens * generatedImages,
|
||||||
|
|||||||
@@ -8,18 +8,19 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
@@ -28,9 +29,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|||||||
|
|
||||||
// 解析为 Gemini 原生响应格式
|
// 解析为 Gemini 原生响应格式
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算使用量(基于 UsageMetadata)
|
// 计算使用量(基于 UsageMetadata)
|
||||||
@@ -51,9 +52,9 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||||
jsonResponse, err := common.EncodeJson(geminiResponse)
|
jsonResponse, err := common.Marshal(geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
@@ -61,7 +62,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|||||||
return &usage, nil
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
var imageCount int
|
var imageCount int
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package gemini
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@@ -792,7 +794,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
|||||||
return &response, isStop, hasImage
|
return &response, isStop, hasImage
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
// responseText := ""
|
// responseText := ""
|
||||||
id := helper.GetResponseID(c)
|
id := helper.GetResponseID(c)
|
||||||
createAt := common.GetTimestamp()
|
createAt := common.GetTimestamp()
|
||||||
@@ -858,33 +860,25 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
}
|
}
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
//resp.Body.Close()
|
//resp.Body.Close()
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println(string(responseBody))
|
println(string(responseBody))
|
||||||
}
|
}
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if len(geminiResponse.Candidates) == 0 {
|
if len(geminiResponse.Candidates) == 0 {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: "No candidates returned",
|
|
||||||
Type: "server_error",
|
|
||||||
Param: "",
|
|
||||||
Code: 500,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||||
fullTextResponse.Model = info.UpstreamModelName
|
fullTextResponse.Model = info.UpstreamModelName
|
||||||
@@ -908,25 +902,25 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
fullTextResponse.Usage = usage
|
fullTextResponse.Usage = usage
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, readErr := io.ReadAll(resp.Body)
|
responseBody, readErr := io.ReadAll(resp.Body)
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
var geminiResponse GeminiEmbeddingResponse
|
var geminiResponse GeminiEmbeddingResponse
|
||||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert to openai format response
|
// convert to openai format response
|
||||||
@@ -947,16 +941,16 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
|
|||||||
// Google has not yet clarified how embedding models will be billed
|
// Google has not yet clarified how embedding models will be billed
|
||||||
// refer to openai billing method to use input tokens billing
|
// refer to openai billing method to use input tokens billing
|
||||||
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
|
||||||
usage = &dto.Usage{
|
usage := &dto.Usage{
|
||||||
PromptTokens: info.PromptTokens,
|
PromptTokens: info.PromptTokens,
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
TotalTokens: info.PromptTokens,
|
TotalTokens: info.PromptTokens,
|
||||||
}
|
}
|
||||||
openAIResponse.Usage = *usage.(*dto.Usage)
|
openAIResponse.Usage = *usage
|
||||||
|
|
||||||
jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
|
jsonResponse, jsonErr := common.Marshal(openAIResponse)
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/common_handler"
|
"one-api/relay/common_handler"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -73,11 +74,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.RelayMode == constant.RelayModeRerank {
|
if info.RelayMode == constant.RelayModeRerank {
|
||||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -69,11 +70,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -84,11 +85,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
|
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = mokaEmbeddingHandler(c, resp)
|
return mokaEmbeddingHandler(c, info, resp)
|
||||||
default:
|
default:
|
||||||
// err, usage = mokaHandler(c, resp)
|
// err, usage = mokaHandler(c, resp)
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ package mokaai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/service"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
|
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
|
||||||
@@ -48,16 +50,16 @@ func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEm
|
|||||||
return &openAIEmbeddingResponse
|
return &openAIEmbeddingResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var baiduResponse dto.EmbeddingResponse
|
var baiduResponse dto.EmbeddingResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
// if baiduResponse.ErrorMsg != "" {
|
// if baiduResponse.ErrorMsg != "" {
|
||||||
// return &dto.OpenAIErrorWithStatusCode{
|
// return &dto.OpenAIErrorWithStatusCode{
|
||||||
@@ -69,12 +71,12 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
// }, nil
|
// }, nil
|
||||||
// }
|
// }
|
||||||
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
|
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
return nil, &fullTextResponse.Usage
|
return &fullTextResponse.Usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -74,14 +75,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
|
usage, err = ollamaEmbeddingHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
package ollama
|
package ollama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
|
||||||
@@ -82,19 +84,19 @@ func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
var ollamaEmbeddingResponse OllamaEmbeddingResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if ollamaEmbeddingResponse.Error != "" {
|
if ollamaEmbeddingResponse.Error != "" {
|
||||||
return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
|
||||||
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
|
||||||
@@ -103,22 +105,22 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
|||||||
Object: "embedding",
|
Object: "embedding",
|
||||||
})
|
})
|
||||||
usage := &dto.Usage{
|
usage := &dto.Usage{
|
||||||
TotalTokens: promptTokens,
|
TotalTokens: info.PromptTokens,
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: info.PromptTokens,
|
||||||
}
|
}
|
||||||
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
embeddingResponse := &dto.OpenAIEmbeddingResponse{
|
||||||
Object: "list",
|
Object: "list",
|
||||||
Data: data,
|
Data: data,
|
||||||
Model: model,
|
Model: info.UpstreamModelName,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
}
|
}
|
||||||
doResponseBody, err := json.Marshal(embeddingResponse)
|
doResponseBody, err := common.Marshal(embeddingResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
func flattenEmbeddings(embeddings [][]float64) []float64 {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"one-api/relay/common_handler"
|
"one-api/relay/common_handler"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -421,31 +422,31 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeRealtime:
|
case relayconstant.RelayModeRealtime:
|
||||||
err, usage = OpenaiRealtimeHandler(c, info)
|
err, usage = OpenaiRealtimeHandler(c, info)
|
||||||
case relayconstant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
usage = OpenaiTTSHandler(c, resp, info)
|
||||||
case relayconstant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
usage, err = OpenaiHandlerWithUsage(c, info, resp)
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
usage, err = common_handler.RerankHandler(c, info, resp)
|
||||||
case relayconstant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
usage, err = OaiResponsesStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = OaiResponsesHandler(c, resp, info)
|
usage, err = OaiResponsesHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = OaiStreamHandler(c, resp, info)
|
usage, err = OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = OpenaiHandler(c, resp, info)
|
usage, err = OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@@ -104,10 +106,10 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
|||||||
return helper.ObjectData(c, lastStreamResponse)
|
return helper.ObjectData(c, lastStreamResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
if resp == nil || resp.Body == nil {
|
if resp == nil || resp.Body == nil {
|
||||||
common.LogError(c, "invalid response or response body")
|
common.LogError(c, "invalid response or response body")
|
||||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
@@ -177,26 +179,23 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
|
|
||||||
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||||
|
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
var simpleResponse dto.OpenAITextResponse
|
var simpleResponse dto.OpenAITextResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
err = common.UnmarshalJson(responseBody, &simpleResponse)
|
err = common.Unmarshal(responseBody, &simpleResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
|
||||||
Error: *simpleResponse.Error,
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
forceFormat := false
|
forceFormat := false
|
||||||
@@ -220,28 +219,28 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
switch info.RelayFormat {
|
switch info.RelayFormat {
|
||||||
case relaycommon.RelayFormatOpenAI:
|
case relaycommon.RelayFormatOpenAI:
|
||||||
if forceFormat {
|
if forceFormat {
|
||||||
responseBody, err = common.EncodeJson(simpleResponse)
|
responseBody, err = common.Marshal(simpleResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
case relaycommon.RelayFormatClaude:
|
case relaycommon.RelayFormatClaude:
|
||||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||||
claudeRespStr, err := common.EncodeJson(claudeResp)
|
claudeRespStr, err := common.Marshal(claudeResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
responseBody = claudeRespStr
|
responseBody = claudeRespStr
|
||||||
}
|
}
|
||||||
|
|
||||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
|
|
||||||
return nil, &simpleResponse.Usage
|
return &simpleResponse.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
|
||||||
// the status code has been judged before, if there is a body reading failure,
|
// the status code has been judged before, if there is a body reading failure,
|
||||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||||
@@ -261,20 +260,20 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
}
|
}
|
||||||
return nil, usage
|
return usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// count tokens by audio file duration
|
// count tokens by audio file duration
|
||||||
audioTokens, err := countAudioTokens(c)
|
audioTokens, err := countAudioTokens(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||||
}
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||||
}
|
}
|
||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||||
@@ -328,9 +327,9 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
|||||||
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
|
||||||
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
|
||||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
|
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
info.IsStream = true
|
info.IsStream = true
|
||||||
@@ -368,7 +367,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
}
|
}
|
||||||
|
|
||||||
realtimeEvent := &dto.RealtimeEvent{}
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
err = common.UnmarshalJson(message, realtimeEvent)
|
err = common.Unmarshal(message, realtimeEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
return
|
return
|
||||||
@@ -428,7 +427,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
|||||||
}
|
}
|
||||||
info.SetFirstResponseTime()
|
info.SetFirstResponseTime()
|
||||||
realtimeEvent := &dto.RealtimeEvent{}
|
realtimeEvent := &dto.RealtimeEvent{}
|
||||||
err = common.UnmarshalJson(message, realtimeEvent)
|
err = common.Unmarshal(message, realtimeEvent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||||
return
|
return
|
||||||
@@ -553,18 +552,18 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
var usageResp dto.SimpleResponse
|
var usageResp dto.SimpleResponse
|
||||||
err = common.UnmarshalJson(responseBody, &usageResp)
|
err = common.Unmarshal(responseBody, &usageResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
@@ -584,5 +583,5 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
|
|||||||
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
|
||||||
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
|
||||||
}
|
}
|
||||||
return nil, &usageResp.Usage
|
return &usageResp.Usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,33 +9,27 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
// read response body
|
// read response body
|
||||||
var responsesResponse dto.OpenAIResponsesResponse
|
var responsesResponse dto.OpenAIResponsesResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
err = common.UnmarshalJson(responseBody, &responsesResponse)
|
err = common.Unmarshal(responseBody, &responsesResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if responsesResponse.Error != nil {
|
if responsesResponse.Error != nil {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
|
||||||
Error: dto.OpenAIError{
|
|
||||||
Message: responsesResponse.Error.Message,
|
|
||||||
Type: "openai_error",
|
|
||||||
Code: responsesResponse.Error.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入新的 response body
|
// 写入新的 response body
|
||||||
@@ -50,13 +44,13 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
for _, tool := range responsesResponse.Tools {
|
for _, tool := range responsesResponse.Tools {
|
||||||
info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
|
info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
|
||||||
}
|
}
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
if resp == nil || resp.Body == nil {
|
if resp == nil || resp.Body == nil {
|
||||||
common.LogError(c, "invalid response or response body")
|
common.LogError(c, "invalid response or response body")
|
||||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
var usage = &dto.Usage{}
|
var usage = &dto.Usage{}
|
||||||
@@ -99,5 +93,5 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -70,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = palmStreamHandler(c, resp)
|
err, responseText = palmStreamHandler(c, resp)
|
||||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
usage, err = palmHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ package palm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||||
@@ -70,7 +73,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
responseId := helper.GetResponseID(c)
|
responseId := helper.GetResponseID(c)
|
||||||
createdTime := common.GetTimestamp()
|
createdTime := common.GetTimestamp()
|
||||||
@@ -121,42 +124,39 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
|||||||
return nil, responseText
|
return nil, responseText
|
||||||
}
|
}
|
||||||
|
|
||||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var palmResponse PaLMChatResponse
|
var palmResponse PaLMChatResponse
|
||||||
err = json.Unmarshal(responseBody, &palmResponse)
|
err = json.Unmarshal(responseBody, &palmResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||||
Error: dto.OpenAIError{
|
Message: palmResponse.Error.Message,
|
||||||
Message: palmResponse.Error.Message,
|
Type: palmResponse.Error.Status,
|
||||||
Type: palmResponse.Error.Status,
|
Param: "",
|
||||||
Param: "",
|
Code: palmResponse.Error.Code,
|
||||||
Code: palmResponse.Error.Code,
|
}, resp.StatusCode)
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: info.PromptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
TotalTokens: promptTokens + completionTokens,
|
TotalTokens: info.PromptTokens + completionTokens,
|
||||||
}
|
}
|
||||||
fullTextResponse.Usage = usage
|
fullTextResponse.Usage = usage
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -73,11 +74,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -76,20 +77,20 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = siliconflowRerankHandler(c, resp)
|
usage, err = siliconflowRerankHandler(c, info, resp)
|
||||||
case constant.RelayModeCompletions:
|
case constant.RelayModeCompletions:
|
||||||
fallthrough
|
fallthrough
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeChatCompletions:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,24 +2,26 @@ package siliconflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/service"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
var siliconflowResp SFRerankResponse
|
var siliconflowResp SFRerankResponse
|
||||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
usage := &dto.Usage{
|
usage := &dto.Usage{
|
||||||
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
|
||||||
@@ -33,10 +35,10 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
|
|||||||
|
|
||||||
jsonResponse, err := json.Marshal(rerankResp)
|
jsonResponse, err := json.Marshal(rerankResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -94,13 +94,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
usage, err = tencentStreamHandler(c, info, resp)
|
||||||
err, responseText = tencentStreamHandler(c, resp)
|
|
||||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
||||||
} else {
|
} else {
|
||||||
err, usage = tencentHandler(c, resp)
|
usage, err = tencentHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,17 +8,20 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://cloud.tencent.com/document/product/1729/97732
|
// https://cloud.tencent.com/document/product/1729/97732
|
||||||
@@ -86,7 +89,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var responseText string
|
var responseText string
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
@@ -126,38 +129,35 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
|||||||
|
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
return nil, responseText
|
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var tencentSb TencentChatResponseSB
|
var tencentSb TencentChatResponseSB
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &tencentSb)
|
err = json.Unmarshal(responseBody, &tencentSb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if tencentSb.Response.Error.Code != 0 {
|
if tencentSb.Response.Error.Code != 0 {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||||
Error: dto.OpenAIError{
|
Message: tencentSb.Response.Error.Message,
|
||||||
Message: tencentSb.Response.Error.Message,
|
Code: tencentSb.Response.Error.Code,
|
||||||
Code: tencentSb.Response.Error.Code,
|
}, resp.StatusCode)
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
|
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
return nil, &fullTextResponse.Usage
|
return &fullTextResponse.Usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -208,19 +209,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
switch a.RequestMode {
|
switch a.RequestMode {
|
||||||
case RequestModeClaude:
|
case RequestModeClaude:
|
||||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||||
case RequestModeGemini:
|
case RequestModeGemini:
|
||||||
if info.RelayMode == constant.RelayModeGemini {
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
|
usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
case RequestModeLlama:
|
case RequestModeLlama:
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
switch a.RequestMode {
|
switch a.RequestMode {
|
||||||
@@ -228,12 +229,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||||
case RequestModeGemini:
|
case RequestModeGemini:
|
||||||
if info.RelayMode == constant.RelayModeGemini {
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
|
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
usage, err = gemini.GeminiChatHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
case RequestModeLlama:
|
case RequestModeLlama:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ import "one-api/common"
|
|||||||
|
|
||||||
func GetModelRegion(other string, localModelName string) string {
|
func GetModelRegion(other string, localModelName string) string {
|
||||||
// if other is json string
|
// if other is json string
|
||||||
if common.IsJsonStr(other) {
|
if common.IsJsonObject(other) {
|
||||||
m := common.StrToMap(other)
|
m, err := common.StrToMap(other)
|
||||||
|
if err != nil {
|
||||||
|
return other // return original if parsing fails
|
||||||
|
}
|
||||||
if m[localModelName] != nil {
|
if m[localModelName] != nil {
|
||||||
return m[localModelName].(string)
|
return m[localModelName].(string)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -225,18 +226,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeChatCompletions:
|
case constant.RelayModeChatCompletions:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
case constant.RelayModeEmbeddings:
|
case constant.RelayModeEmbeddings:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
@@ -95,15 +96,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
|
||||||
default:
|
default:
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = xAIStreamHandler(c, resp, info)
|
usage, err = xAIStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = xAIHandler(c, resp, info)
|
usage, err = xAIHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -34,7 +35,7 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
|
|||||||
return openAIResp
|
return openAIResp
|
||||||
}
|
}
|
||||||
|
|
||||||
func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
var responseTextBuilder strings.Builder
|
var responseTextBuilder strings.Builder
|
||||||
var toolCount int
|
var toolCount int
|
||||||
@@ -74,30 +75,28 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
|
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
defer common.CloseResponseBodyGracefully(resp)
|
defer common.CloseResponseBodyGracefully(resp)
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
var response *dto.SimpleResponse
|
var response *dto.SimpleResponse
|
||||||
err = common.UnmarshalJson(responseBody, &response)
|
err = common.Unmarshal(responseBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
|
||||||
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||||
|
|
||||||
// new body
|
// new body
|
||||||
encodeJson, err := common.EncodeJson(response)
|
encodeJson, err := common.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||||
|
|
||||||
return nil, &response.Usage
|
return &response.Usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -74,18 +74,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return dummyResp, nil
|
return dummyResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
splits := strings.Split(info.ApiKey, "|")
|
splits := strings.Split(info.ApiKey, "|")
|
||||||
if len(splits) != 3 {
|
if len(splits) != 3 {
|
||||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey)
|
||||||
}
|
}
|
||||||
if a.request == nil {
|
if a.request == nil {
|
||||||
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
|
return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||||
} else {
|
} else {
|
||||||
err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
|
usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,18 +6,18 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://console.xfyun.cn/services/cbm
|
// https://console.xfyun.cn/services/cbm
|
||||||
@@ -126,11 +126,11 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|||||||
return callUrl
|
return callUrl
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
|
||||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
var usage dto.Usage
|
var usage dto.Usage
|
||||||
@@ -153,14 +153,14 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
|
||||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
|
||||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
var usage dto.Usage
|
var usage dto.Usage
|
||||||
var content string
|
var content string
|
||||||
@@ -191,11 +191,11 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
|
|||||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
_, _ = c.Writer.Write(jsonResponse)
|
_, _ = c.Writer.Write(jsonResponse)
|
||||||
return nil, &usage
|
return &usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -77,11 +78,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = zhipuStreamHandler(c, resp)
|
usage, err = zhipuStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = zhipuHandler(c, resp)
|
usage, err = zhipuHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,18 +3,20 @@ package zhipu
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://open.bigmodel.cn/doc/api#chatglm_std
|
// https://open.bigmodel.cn/doc/api#chatglm_std
|
||||||
@@ -150,7 +152,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
|
|||||||
return &response, &zhipuResponse.Usage
|
return &response, &zhipuResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var usage *dto.Usage
|
var usage *dto.Usage
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(bufio.ScanLines)
|
scanner.Split(bufio.ScanLines)
|
||||||
@@ -211,38 +213,33 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
return nil, usage
|
return usage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
var zhipuResponse ZhipuResponse
|
var zhipuResponse ZhipuResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
if !zhipuResponse.Success {
|
if !zhipuResponse.Success {
|
||||||
return &dto.OpenAIErrorWithStatusCode{
|
return nil, types.WithOpenAIError(types.OpenAIError{
|
||||||
Error: dto.OpenAIError{
|
Message: zhipuResponse.Msg,
|
||||||
Message: zhipuResponse.Msg,
|
Code: zhipuResponse.Code,
|
||||||
Type: "zhipu_error",
|
}, resp.StatusCode)
|
||||||
Param: "",
|
|
||||||
Code: zhipuResponse.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
return nil, &fullTextResponse.Usage
|
return &fullTextResponse.Usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -80,11 +81,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
return channel.DoApiRequest(a, c, info, requestBody)
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -14,7 +12,10 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
|
||||||
@@ -32,14 +33,14 @@ func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest
|
|||||||
return textRequest, nil
|
return textRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
|
|
||||||
relayInfo := relaycommon.GenRelayInfoClaude(c)
|
relayInfo := relaycommon.GenRelayInfoClaude(c)
|
||||||
|
|
||||||
// get & validate textRequest 获取并验证文本请求
|
// get & validate textRequest 获取并验证文本请求
|
||||||
textRequest, err := getAndValidateClaudeRequest(c)
|
textRequest, err := getAndValidateClaudeRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
if textRequest.Stream {
|
if textRequest.Stream {
|
||||||
@@ -48,35 +49,35 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
|||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||||
// count messages token error 计算promptTokens错误
|
// count messages token error 计算promptTokens错误
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-consume quota 预消耗配额
|
// pre-consume quota 预消耗配额
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
|
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
return newAPIError
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
@@ -109,14 +110,14 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
|||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := common.Marshal(convertedRequest)
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
println("requestBody: ", string(jsonData))
|
println("requestBody: ", string(jsonData))
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
|
|
||||||
@@ -124,26 +125,26 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
|||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
//log.Printf("usage: %v", usage)
|
//log.Printf("usage: %v", usage)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return service.OpenAIErrorToClaudeError(openaiErr)
|
return newAPIError
|
||||||
}
|
}
|
||||||
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
|||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
|
||||||
|
|
||||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||||
@@ -229,7 +229,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||||
isFirstResponse: true,
|
isFirstResponse: true,
|
||||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
|
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
|
||||||
RequestURLPath: c.Request.URL.String(),
|
RequestURLPath: c.Request.URL.String(),
|
||||||
ChannelType: channelType,
|
ChannelType: channelType,
|
||||||
ChannelId: channelId,
|
ChannelId: channelId,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package common_handler
|
package common_handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -9,13 +8,15 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel/xinference"
|
"one-api/relay/channel/xinference"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||||
}
|
}
|
||||||
common.CloseResponseBodyGracefully(resp)
|
common.CloseResponseBodyGracefully(resp)
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
@@ -24,9 +25,9 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
var jinaResp dto.RerankResponse
|
var jinaResp dto.RerankResponse
|
||||||
if info.ChannelType == constant.ChannelTypeXinference {
|
if info.ChannelType == constant.ChannelTypeXinference {
|
||||||
var xinRerankResponse xinference.XinRerankResponse
|
var xinRerankResponse xinference.XinRerankResponse
|
||||||
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
|
err = common.Unmarshal(responseBody, &xinRerankResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
|
||||||
for i, result := range xinRerankResponse.Results {
|
for i, result := range xinRerankResponse.Results {
|
||||||
@@ -59,14 +60,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = common.UnmarshalJson(responseBody, &jinaResp)
|
err = common.Unmarshal(responseBody, &jinaResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||||
}
|
}
|
||||||
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
c.JSON(http.StatusOK, jinaResp)
|
c.JSON(http.StatusOK, jinaResp)
|
||||||
return nil, &jinaResp.Usage
|
return &jinaResp.Usage, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -12,6 +11,9 @@ import (
|
|||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||||
@@ -32,24 +34,24 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
||||||
|
|
||||||
var embeddingRequest *dto.EmbeddingRequest
|
var embeddingRequest *dto.EmbeddingRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||||
@@ -57,57 +59,57 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
// pre-consume quota 预消耗配额
|
// pre-consume quota 预消耗配额
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -104,11 +105,11 @@ func trimModelThinking(modelName string) string {
|
|||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
req, err := getAndValidateGeminiRequest(c)
|
req, err := getAndValidateGeminiRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||||
@@ -120,14 +121,14 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// model mapped 模型映射
|
// model mapped 模型映射
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, exists := c.Get("prompt_tokens"); exists {
|
if value, exists := c.Get("prompt_tokens"); exists {
|
||||||
@@ -158,23 +159,23 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre consume quota
|
// pre consume quota
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
@@ -195,7 +196,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
|
|
||||||
requestBody, err := json.Marshal(req)
|
requestBody, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
@@ -205,7 +206,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
@@ -215,10 +216,10 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,27 +4,29 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetEventStreamHeaders(c *gin.Context) {
|
func SetEventStreamHeaders(c *gin.Context) {
|
||||||
// 检查是否已经设置过头部
|
// 检查是否已经设置过头部
|
||||||
if _, exists := c.Get("event_stream_headers_set"); exists {
|
if _, exists := c.Get("event_stream_headers_set"); exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
c.Writer.Header().Set("Connection", "keep-alive")
|
c.Writer.Header().Set("Connection", "keep-alive")
|
||||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
// 设置标志,表示头部已经设置过
|
// 设置标志,表示头部已经设置过
|
||||||
c.Set("event_stream_headers_set", true)
|
c.Set("event_stream_headers_set", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
|
||||||
@@ -85,7 +87,7 @@ func ObjectData(c *gin.Context, object interface{}) error {
|
|||||||
if object == nil {
|
if object == nil {
|
||||||
return errors.New("object is nil")
|
return errors.New("object is nil")
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(object)
|
jsonData, err := common.Marshal(object)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error marshalling object: %w", err)
|
return fmt.Errorf("error marshalling object: %w", err)
|
||||||
}
|
}
|
||||||
@@ -118,7 +120,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
|||||||
return ws.WriteMessage(1, jsonData)
|
return ws.WriteMessage(1, jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
|
func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
|
||||||
errorObj := &dto.RealtimeEvent{
|
errorObj := &dto.RealtimeEvent{
|
||||||
Type: "error",
|
Type: "error",
|
||||||
EventId: GetLocalRealtimeID(c),
|
EventId: GetLocalRealtimeID(c),
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -107,23 +108,23 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|||||||
return imageRequest, nil
|
return imageRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
relayInfo := relaycommon.GenRelayInfoImage(c)
|
relayInfo := relaycommon.GenRelayInfoImage(c)
|
||||||
|
|
||||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
var preConsumedQuota int
|
var preConsumedQuota int
|
||||||
var quota int
|
var quota int
|
||||||
@@ -132,13 +133,12 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
// modelRatio 16 = modelPrice $0.04
|
// modelRatio 16 = modelPrice $0.04
|
||||||
// per 1 modelRatio = $0.04 / 16
|
// per 1 modelRatio = $0.04 / 16
|
||||||
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
// priceData.ModelPrice = 0.0025 * priceData.ModelRatio
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
if newAPIError != nil {
|
||||||
if openaiErr != nil {
|
return newAPIError
|
||||||
return openaiErr
|
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -169,16 +169,16 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeQueryDataError)
|
||||||
}
|
}
|
||||||
if userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
|
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
@@ -186,14 +186,14 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||||
requestBody = convertedRequest.(io.Reader)
|
requestBody = convertedRequest.(io.Reader)
|
||||||
} else {
|
} else {
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
@@ -206,25 +206,25 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr := service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
|
|
||||||
if usage.(*dto.Usage).TotalTokens == 0 {
|
if usage.(*dto.Usage).TotalTokens == 0 {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -84,7 +85,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
|||||||
return textRequest, nil
|
return textRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
|
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
@@ -92,8 +93,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if textRequest.WebSearchOptions != nil {
|
if textRequest.WebSearchOptions != nil {
|
||||||
@@ -104,13 +104,13 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||||
@@ -122,23 +122,23 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
||||||
// count messages token error 计算promptTokens错误
|
// count messages token error 计算promptTokens错误
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||||
}
|
}
|
||||||
c.Set("prompt_tokens", promptTokens)
|
c.Set("prompt_tokens", promptTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-consume quota 预消耗配额
|
// pre-consume quota 预消耗配额
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if openaiErr != nil {
|
if newApiErr != nil {
|
||||||
return openaiErr
|
return newApiErr
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newApiErr != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -166,7 +166,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
@@ -174,32 +174,29 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||||
body, err := common.GetRequestBody(c)
|
body, err := common.GetRequestBody(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
|
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(body)
|
requestBody = bytes.NewBuffer(body)
|
||||||
} else {
|
} else {
|
||||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(relayInfo.ParamOverride) > 0 {
|
if len(relayInfo.ParamOverride) > 0 {
|
||||||
reqMap := make(map[string]interface{})
|
reqMap := make(map[string]interface{})
|
||||||
err = json.Unmarshal(jsonData, &reqMap)
|
_ = common.Unmarshal(jsonData, &reqMap)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
for key, value := range relayInfo.ParamOverride {
|
for key, value := range relayInfo.ParamOverride {
|
||||||
reqMap[key] = value
|
reqMap[key] = value
|
||||||
}
|
}
|
||||||
jsonData, err = json.Marshal(reqMap)
|
jsonData, err = common.Marshal(reqMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,7 +210,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
@@ -222,18 +219,18 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newApiErr = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newApiErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if newApiErr != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newApiErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||||
@@ -281,16 +278,16 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 预扣费并返回用户剩余配额
|
// 预扣费并返回用户剩余配额
|
||||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
|
||||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
|
||||||
}
|
}
|
||||||
if userQuota <= 0 {
|
if userQuota <= 0 {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||||
}
|
}
|
||||||
if userQuota-preConsumedQuota < 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
|
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
|
||||||
}
|
}
|
||||||
relayInfo.UserQuota = userQuota
|
relayInfo.UserQuota = userQuota
|
||||||
if userQuota > 100*preConsumedQuota {
|
if userQuota > 100*preConsumedQuota {
|
||||||
@@ -314,11 +311,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
|
||||||
}
|
}
|
||||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return preConsumedQuota, userQuota, nil
|
return preConsumedQuota, userQuota, nil
|
||||||
|
|||||||
@@ -2,15 +2,16 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||||
@@ -22,27 +23,27 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
|||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
|
||||||
|
|
||||||
var rerankRequest *dto.RerankRequest
|
var rerankRequest *dto.RerankRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
err := common.UnmarshalBodyReusable(c, &rerankRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
|
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
|
||||||
|
|
||||||
if rerankRequest.Query == "" {
|
if rerankRequest.Query == "" {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
if len(rerankRequest.Documents) == 0 {
|
if len(rerankRequest.Documents) == 0 {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
promptToken := getRerankPromptToken(*rerankRequest)
|
promptToken := getRerankPromptToken(*rerankRequest)
|
||||||
@@ -50,32 +51,32 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
// pre-consume quota 预消耗配额
|
// pre-consume quota 预消耗配额
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := common.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
@@ -83,7 +84,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|||||||
}
|
}
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
@@ -91,18 +92,18 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -46,11 +47,11 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo
|
|||||||
return inputTokens
|
return inputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||||
req, err := getAndValidateResponsesRequest(c)
|
req, err := getAndValidateResponsesRequest(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
|
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
|
||||||
@@ -59,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
sensitiveWords, err := checkInputSensitive(req, relayInfo)
|
sensitiveWords, err := checkInputSensitive(req, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||||
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if value, exists := c.Get("prompt_tokens"); exists {
|
if value, exists := c.Get("prompt_tokens"); exists {
|
||||||
@@ -78,52 +79,52 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
// pre consume quota
|
// pre consume quota
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||||
body, err := common.GetRequestBody(c)
|
body, err := common.GetRequestBody(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(body)
|
requestBody = bytes.NewBuffer(body)
|
||||||
} else {
|
} else {
|
||||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
|
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
// apply param override
|
// apply param override
|
||||||
if len(relayInfo.ParamOverride) > 0 {
|
if len(relayInfo.ParamOverride) > 0 {
|
||||||
reqMap := make(map[string]interface{})
|
reqMap := make(map[string]interface{})
|
||||||
err = json.Unmarshal(jsonData, &reqMap)
|
err = json.Unmarshal(jsonData, &reqMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
|
||||||
}
|
}
|
||||||
for key, value := range relayInfo.ParamOverride {
|
for key, value := range relayInfo.ParamOverride {
|
||||||
reqMap[key] = value
|
reqMap[key] = value
|
||||||
}
|
}
|
||||||
jsonData, err = json.Marshal(reqMap)
|
jsonData, err = json.Marshal(reqMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +137,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
@@ -145,18 +146,18 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
|
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
package relay
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"net/http"
|
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
|
||||||
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
||||||
|
|
||||||
// get & validate textRequest 获取并验证文本请求
|
// get & validate textRequest 获取并验证文本请求
|
||||||
@@ -22,42 +22,31 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
|||||||
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||||
//}
|
//}
|
||||||
|
|
||||||
// map model name
|
err := helper.ModelMappedHelper(c, relayInfo, nil)
|
||||||
modelMapping := c.GetString("model_mapping")
|
if err != nil {
|
||||||
//isModelMapped := false
|
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
|
||||||
modelMap := make(map[string]string)
|
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[relayInfo.OriginModelName] != "" {
|
|
||||||
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
|
|
||||||
// set upstream model name
|
|
||||||
//isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-consume quota 预消耗配额
|
// pre-consume quota 预消耗配额
|
||||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||||
}
|
}
|
||||||
adaptor.Init(relayInfo)
|
adaptor.Init(relayInfo)
|
||||||
//var requestBody io.Reader
|
//var requestBody io.Reader
|
||||||
@@ -67,7 +56,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
|||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
@@ -75,11 +64,11 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
|||||||
defer relayInfo.TargetWs.Close()
|
defer relayInfo.TargetWs.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
|
usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo)
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
// reset status code 重置状态码
|
// reset status code 重置状态码
|
||||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return openaiErr
|
return newAPIError
|
||||||
}
|
}
|
||||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||||
userQuota, priceData, "")
|
userQuota, priceData, "")
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
|
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
|
||||||
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
|
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
|
||||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||||
apiRouter.GET("/jsrt/reload", middleware.AdminAuth(), controller.ReloadJSScripts)
|
|
||||||
apiRouter.GET("/notice", controller.GetNotice)
|
apiRouter.GET("/notice", controller.GetNotice)
|
||||||
apiRouter.GET("/about", controller.GetAbout)
|
apiRouter.GET("/about", controller.GetAbout)
|
||||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||||
|
|||||||
@@ -3,21 +3,14 @@ package router
|
|||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/middleware/jsrt"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||||
jsrtMid := jsrt.JSRuntimeMiddleware()
|
|
||||||
if jsrtMid != nil {
|
|
||||||
router.Use(*jsrtMid)
|
|
||||||
}
|
|
||||||
|
|
||||||
SetApiRouter(router)
|
SetApiRouter(router)
|
||||||
SetDashboardRouter(router)
|
SetDashboardRouter(router)
|
||||||
SetRelayRouter(router)
|
SetRelayRouter(router)
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
router.Use(middleware.CORS())
|
router.Use(middleware.CORS())
|
||||||
router.Use(middleware.DecompressRequestMiddleware())
|
router.Use(middleware.DecompressRequestMiddleware())
|
||||||
router.Use(middleware.StatsMiddleware())
|
router.Use(middleware.StatsMiddleware())
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/introduction
|
// https://platform.openai.com/docs/api-reference/introduction
|
||||||
modelsRouter := router.Group("/v1/models")
|
modelsRouter := router.Group("/v1/models")
|
||||||
modelsRouter.Use(middleware.TokenAuth())
|
modelsRouter.Use(middleware.TokenAuth())
|
||||||
@@ -21,7 +20,7 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||||
}
|
}
|
||||||
playgroundRouter := router.Group("/pg")
|
playgroundRouter := router.Group("/pg")
|
||||||
playgroundRouter.Use(middleware.UserAuth())
|
playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
playgroundRouter.POST("/chat/completions", controller.Playground)
|
playgroundRouter.POST("/chat/completions", controller.Playground)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
// Utility functions for JavaScript runtime
|
|
||||||
|
|
||||||
function logWithReq(req, message) {
|
|
||||||
let reqPath = req.url || 'unknown path';
|
|
||||||
console.log(`[${req.method} ${reqPath}] ${message}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
function safeJsonParse(str, defaultValue = null) {
|
|
||||||
try {
|
|
||||||
return JSON.parse(str);
|
|
||||||
} catch (e) {
|
|
||||||
console.error('JSON parse error:', e.message);
|
|
||||||
return defaultValue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user