Compare commits


2 Commits

Author SHA1 Message Date
Xyfacai
606aa8a4a7 feat: debug suno 2024-06-13 00:32:14 +08:00
Xiangyuan Liu
1e8abc7027 feat: suno api support 2024-06-12 20:37:42 +08:00
126 changed files with 3214 additions and 5630 deletions

View File

@@ -4,7 +4,6 @@ on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:

View File

@@ -2,21 +2,6 @@
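**Introduction**: Midjourney Proxy API documentation
## Endpoint list
The following endpoints are supported:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (images are fetched through this endpoint; **you must fill in the server address in the system settings!**)
+ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is forwarded through One API)
+ [x] /task/list-by-condition
+ [x] /mj/submit/action (supported only by midjourney-proxy-plus; the same applies to the entries below)
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap (InsightFace)
## Model list
### Supported by midjourney-proxy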

View File

@@ -5,6 +5,8 @@
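> This is an open-source project developed on top of [One API](https://github.com/songquanpeng/one-api); thanks to the original author for their selfless contribution.
> Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
> [!WARNING]
> This project is for personal learning only. Stability is not guaranteed and no technical support is provided. Users must comply with OpenAI's Terms of Use and applicable laws and regulations, and must not use it for illegal purposes.
> In accordance with the [Interim Measures for the Administration of Generative Artificial Intelligence Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), do not provide any unregistered generative AI service to the public in China.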
@@ -16,7 +18,19 @@
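The main changes in this fork are:
1. A brand-new UI (some pages are still to be updated)
2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration guide](Midjourney.md)
2. Added support for the [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration guide](Midjourney.md), with the following endpoints supported:
+ [x] /mj/submit/imagine
+ [x] /mj/submit/change
+ [x] /mj/submit/blend
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (images are fetched through this endpoint; **you must fill in the server address in the system settings!**)
+ [x] /mj/task/{id}/fetch (the image URL returned by this endpoint is forwarded through One API)
+ [x] /task/list-by-condition
+ [x] /mj/submit/action (supported only by midjourney-proxy-plus; the same applies to the entries below)
+ [x] /mj/submit/modal
+ [x] /mj/submit/shorten
+ [x] /mj/task/{id}/image-seed
+ [x] /mj/insight-face/swap (InsightFace)
3. Online top-up support, configurable in the system settings. Currently supported payment gateways:
+ [x] Epay (易支付)
4. Support for querying usage quota by key: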
@@ -33,21 +47,16 @@
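2. Send the /setdomain command to [@Botfather](https://t.me/botfather)
3. Select your bot, then enter http(s)://your-site-address/login
4. The Telegram Bot name is the bot username without the leading @
13. Added support for the [Suno API](https://github.com/Suno-API/Suno-API) interface, [integration guide](Suno.md)
14. Rerank model support, currently compatible only with Cohere and Jina; can be connected to Dify, [integration guide](Rerank.md)
## Model support
This version additionally supports the following models:
1. Third-party model **gps** (gpt-4-gizmo-*)
2. Zhipu glm-4v, with glm-4v image recognition
3. Anthropic Claude 3
3. Anthropic Claude 3 (claude-3-opus-20240229, claude-3-sonnet-20240229)
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file): when adding a channel, the key can be any value; the default request address is [http://localhost:11434](http://localhost:11434), which can be changed in the channel if needed
5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) API, [integration guide](Midjourney.md)
6. [01.AI (零一万物)](https://platform.lingyiwanwu.com/)
7. Custom channels, which accept a full call URL
8. [Suno API](https://github.com/Suno-API/Suno-API) interface, [integration guide](Suno.md)
9. Rerank models, currently supporting [Cohere](https://cohere.ai/) and [Jina](https://jina.ai/), [integration guide](Rerank.md)
10. Dify
You can add the custom model gpt-4-gizmo-* in a channel. This is not an official OpenAI model but a third-party one, and it cannot be called with an official key.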
@@ -70,17 +79,8 @@
```
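This turns 400 errors into 500 errors so that the request is retried
## Configuration options added on top of the original One API
- `STREAMING_TIMEOUT`: timeout for a single streamed reply, default 30 seconds
- `DIFY_DEBUG`: whether the Dify channel outputs workflow and node information to the client, default `true`
- `FORCE_STREAM_OPTION`: whether to override the client's stream_options parameter and ask the upstream to return usage in stream mode, default `true`
- `GET_MEDIA_TOKEN`: whether to count image tokens, default `true`. When disabled, image tokens are no longer computed locally, which may make billing differ from the upstream. This option overrides `GET_MEDIA_TOKEN_NOT_STREAM`
- `GET_MEDIA_TOKEN_NOT_STREAM`: whether to count image tokens in the non-streaming (`stream=false`) case, default `true`
- `UPDATE_TASK`: whether to update asynchronous tasks (Midjourney, Suno), default `true`; when disabled, task progress is not updated
## Deployment
### Requirements
- Local database (default): SQLite (Docker deployments use SQLite by default and must mount the `/data` directory to the host)
- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6
### Deploying with Docker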
```shell
# Deployment command using SQLite:
@@ -105,9 +105,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
## Midjourney API setup documentation
[Integration guide](Midjourney.md)
## Suno API setup documentation
[Integration guide](Suno.md)
## Community group
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">

View File

@@ -1,62 +0,0 @@
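# Rerank API documentation
**Introduction**: Rerank API documentation
## Connecting to Dify
Select Jina as the model provider and fill in the model information as required to connect to Dify.
## Request format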
Post: /v1/rerank
Request:
```json
{
"model": "rerank-multilingual-v3.0",
"query": "What is the capital of the United States?",
"top_n": 3,
"documents": [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}
```
Response:
```json
{
"results": [
{
"document": {
"text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
},
"index": 2,
"relevance_score": 0.9999702
},
{
"document": {
"text": "Carson City is the capital city of the American state of Nevada."
},
"index": 0,
"relevance_score": 0.67800725
},
{
"document": {
"text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
},
"index": 3,
"relevance_score": 0.02800752
}
],
"usage": {
"prompt_tokens": 158,
"completion_tokens": 0,
"total_tokens": 158
}
}
```

44
Suno.md
View File

@@ -1,44 +0,0 @@
# Suno API documentation
**Introduction**: Suno API documentation
## Endpoint list
The following endpoints are supported:
+ [x] /suno/submit/music
+ [x] /suno/submit/lyrics
+ [x] /suno/fetch
+ [x] /suno/fetch/:id
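## Model list
### Supported by Suno API
- suno_music (custom mode, inspiration mode, continuation)
- suno_lyrics (lyrics generation)
## Model price settings (set under Settings > Operation settings > Fixed model price settings)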
```json
{
"suno_music": 0.3,
"suno_lyrics": 0.01
}
```
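## Channel settings
### Connecting to Suno API
1. Deploy Suno API and configure the Suno account and related settings (setting a key is strongly recommended), [project page](https://github.com/Suno-API/Suno-API)
2. Add a channel in channel management and select **Suno API** as the channel type; for models, refer to the model list above
3. In **Proxy**, enter the address where Suno API is deployed, e.g. http://localhost:8080
4. In the key field, enter the Suno API key; if no key was set, any value works
### Connecting to an upstream new api
1. Add a channel in channel management and select **Suno API** as the channel type, or any type, as long as its models include those in the model list above
2. In **Proxy**, enter the address of the upstream new api, e.g. http://localhost:3000
3. In the key field, enter the key of the upstream new api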

View File

@@ -103,14 +103,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
@@ -133,10 +133,10 @@ var (
// All duration's unit is seconds
// Shouldn't be larger than RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@@ -210,39 +210,36 @@ const (
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
@@ -256,7 +253,4 @@ var ChannelBaseURLs = []string{
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
}

View File

@@ -1,38 +0,0 @@
package common
import (
"fmt"
"os"
"strconv"
)
func GetEnvOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetEnvOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
b, err := strconv.ParseBool(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
return defaultValue
}
return b
}

View File

@@ -3,7 +3,6 @@ package common
import (
"fmt"
"runtime/debug"
"time"
)
func SafeGoroutine(f func()) {
@@ -46,21 +45,3 @@ func SafeSendString(ch chan string, value string) (closed bool) {
// If the code reaches here, then the channel was not closed.
return false
}
// SafeSendStringTimeout send, return true, else return false
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
defer func() {
// Recover from panic if one occurred. A panic would mean the channel was closed.
if recover() != nil {
closed = false
}
}()
// This will panic if the channel is closed.
select {
case ch <- value:
return true
case <-time.After(time.Duration(timeout) * time.Second):
return false
}
}

View File

@@ -1,8 +1,6 @@
package common
import (
"encoding/json"
)
import "encoding/json"
var GroupRatio = map[string]float64{
"default": 1,

View File

@@ -72,29 +72,24 @@ var defaultModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-5-sonnet-20240620": 1.5,
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
"claude-instant-1": 0.4, // $0.8 / 1M tokens
"claude-2.0": 4, // $8 / 1M tokens
"claude-2.1": 4, // $8 / 1M tokens
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K
"ERNIE-4.0-8K": 8.572, // ¥0.12 / 1k tokens
"ERNIE-3.5-8K": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Speed-8K": 0.2858, // ¥0.004 / 1k tokens
"ERNIE-Speed-128K": 0.2858, // ¥0.004 / 1k tokens
"ERNIE-Lite-8K": 0.2143, // ¥0.003 / 1k tokens
"ERNIE-Tiny-8K": 0.0715, // ¥0.001 / 1k tokens
"ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens
"ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -105,13 +100,12 @@ var defaultModelRatio = map[string]float64{
"gemini-1.0-pro-latest": 1,
"gemini-1.0-pro-vision-latest": 1,
"gemini-ultra": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens
"glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"glm-4": 7.143, // ¥0.1 / 1k tokens
"glm-4v": 7.143, // ¥0.1 / 1k tokens
"glm-3-turbo": 0.3572,
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
"qwen-plus": 10, // ¥0.14 / 1k tokens
@@ -120,7 +114,6 @@ var defaultModelRatio = map[string]float64{
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858,
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens
"360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens
@@ -159,8 +152,6 @@ var defaultModelRatio = map[string]float64{
}
var defaultModelPrice = map[string]float64{
"suno_music": 0.1,
"suno_lyrics": 0.01,
"dall-e-3": 0.04,
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
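The ratios in this hunk can be checked against the prices in their comments. The comments imply that a ratio of 1 corresponds to $0.002 per 1K tokens (for example, "$3 / 1M tokens" maps to 1.5), and the RMB entries are consistent with a rate of about 7 RMB per USD. Both constants below are inferred from those comments, not taken from the project's code, so treat this as a sketch of the arithmetic only.

```go
package main

import "fmt"

const (
	// Assumption inferred from "$3 / 1M tokens" -> ratio 1.5.
	usdPer1KAtRatioOne = 0.002
	// Assumption inferred from "¥0.012 / 1k tokens" -> ratio 0.8572.
	rmbPerUSD = 7.0
)

// ratioFromUSDPerMillion converts a price in USD per 1M tokens to a ratio.
func ratioFromUSDPerMillion(usd float64) float64 {
	return usd / 1000 / usdPer1KAtRatioOne
}

// ratioFromRMBPerThousand converts a price in RMB per 1K tokens to a ratio.
func ratioFromRMBPerThousand(rmb float64) float64 {
	return rmb / rmbPerUSD / usdPer1KAtRatioOne
}

func main() {
	fmt.Printf("claude-3-sonnet: %.4f\n", ratioFromUSDPerMillion(3))
	fmt.Printf("ERNIE-3.5-8K:    %.4f\n", ratioFromRMBPerThousand(0.012))
}
```

The RMB result is about 0.8571, while the table lists 0.8572; the listed values appear to be rounded up in the last digit.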

View File

@@ -1,73 +1,80 @@
package common
import (
"encoding/json"
"math/rand"
"strconv"
"unsafe"
"bytes"
"fmt"
goahocorasick "github.com/anknown/ahocorasick"
"one-api/constant"
"strings"
)
func GetStringIfEmpty(str string, defaultValue string) string {
if str == "" {
return defaultValue
func SundaySearch(text string, pattern string) bool {
// Build the offset table
offset := make(map[rune]int)
for i, c := range pattern {
offset[c] = len(pattern) - i
}
return str
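// Lengths of the text and the pattern
n, m := len(text), len(pattern)
// Main loop; i is the position in the text currently aligned with the pattern
for i := 0; i <= n-m; {
// Compare the substring
j := 0
for j < m && text[i+j] == pattern[j] {
j++
}
// On a full match, report success
if j == m {
return true
}
// If characters remain, look up the character just past the window in the offset table
if i+m < n {
next := rune(text[i+m])
if val, ok := offset[next]; ok {
i += val // present in the offset table: jump by the stored offset
} else {
i += len(pattern) + 1 // absent from the offset table: skip past the whole pattern
}
} else {
break
}
}
return false // no match found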
}
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
func RemoveDuplicate(s []string) []string {
result := make([]string, 0, len(s))
temp := map[string]struct{}{}
for _, item := range s {
if _, ok := temp[item]; !ok {
temp[item] = struct{}{}
result = append(result, item)
}
}
return string(key)
return result
}
func MapToJsonStr(m map[string]interface{}) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
func InitAc() *goahocorasick.Machine {
m := new(goahocorasick.Machine)
dict := readRunes()
if err := m.Build(dict); err != nil {
fmt.Println(err)
return nil
}
return m
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func readRunes() [][]rune {
var dict [][]rune
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
for _, word := range constant.SensitiveWords {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))
}
return false
}
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
return dict
}
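The `SundaySearch` function in this file builds its offset table from runes but indexes `text` and `pattern` by byte, which appears to line up only for ASCII input. Below is a minimal sketch of the same Sunday algorithm over `[]rune`, so multi-byte characters are compared as whole characters. The `sundayContains` name is hypothetical and this is not the project's implementation.

```go
package main

import "fmt"

// sundayContains reports whether pattern occurs in text, using the Sunday
// shift rule on rune slices.
func sundayContains(text, pattern string) bool {
	t, p := []rune(text), []rune(pattern)
	n, m := len(t), len(p)
	if m == 0 {
		return true
	}
	// shift[c] is the distance from the last occurrence of c in the pattern
	// to one position past the end of the pattern.
	shift := make(map[rune]int, m)
	for i, c := range p {
		shift[c] = m - i
	}
	for i := 0; i+m <= n; {
		j := 0
		for j < m && t[i+j] == p[j] {
			j++
		}
		if j == m {
			return true
		}
		if i+m >= n {
			break // no character after the window, so no further alignment fits
		}
		if s, ok := shift[t[i+m]]; ok {
			i += s // align the next character with its last occurrence in the pattern
		} else {
			i += m + 1 // the next character is not in the pattern: skip past it
		}
	}
	return false
}

func main() {
	fmt.Println(sundayContains("hello world", "o w"))
	fmt.Println(sundayContains("敏感词检测", "检测"))
	fmt.Println(sundayContains("hello", "xyz"))
}
```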

View File

@@ -1,8 +1,6 @@
package common
import (
"encoding/json"
)
import "encoding/json"
var TopupGroupRatio = map[string]float64{
"default": 1,

View File

@@ -1,17 +1,20 @@
package common
import (
"encoding/json"
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
"unsafe"
)
func OpenBrowser(url string) {
@@ -157,6 +160,15 @@ func GenerateKey() string {
return string(key)
}
func GetRandomString(length int) string {
//rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomInt(max int) int {
//rand.Seed(time.Now().UnixNano())
return rand.Intn(max)
@@ -179,11 +191,79 @@ func Max(a int, b int) int {
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}
// StringToByteSlice []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
func RandomSleep() {
// Sleep for 0-3000 ms
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
}
func MapToJsonStr(m map[string]interface{}) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func MapToJsonStrFloat(m map[string]float64) string {
bytes, err := json.Marshal(m)
if err != nil {
return ""
}
return string(bytes)
}
func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return nil
}
return m
}

View File

@@ -1,17 +0,0 @@
package constant
import (
"one-api/common"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
// ForceStreamOption overrides the request parameters to force usage information to be returned
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)

View File

@@ -4,7 +4,6 @@ var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const (
MjErrorUnknown = 5

View File

@@ -5,15 +5,12 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
@@ -26,8 +23,7 @@ import (
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now()
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
@@ -42,37 +38,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
Body: nil,
Header: make(http.Header),
}
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
}
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, testModel)
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
}
meta := relaycommon.GenRelayInfo(c)
apiType, _ := constant.ChannelType2APIType(channel.Type)
@@ -80,15 +61,35 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
if testModel == "" {
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
testModel = adaptor.GetModelList()[0]
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
return err, &openaiErr
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
adaptor.Init(meta)
adaptor.Init(meta, *request)
convertedRequest, err := adaptor.ConvertRequest(c, meta, request)
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
@@ -103,40 +104,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return err, nil
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
modelPrice, usePrice := common.GetModelPrice(testModel, false)
modelRatio := common.GetModelRatio(testModel)
completionRatio := common.GetCompletionRatio(testModel)
ratio := modelRatio
quota := 0
if !usePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio))
quota = int(math.Round(float64(quota) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit)
}
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
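The quota logic in the lines of this hunk that record the test log has two branches: a fixed per-call price, or a token count scaled by the completion ratio and the model ratio, with a floor of 1 when the ratio is non-zero. A standalone sketch of that arithmetic follows. The `computeQuota` name is hypothetical, and `quotaPerUnit` is an assumed stand-in for `common.QuotaPerUnit`, whose value is not shown in this diff.

```go
package main

import (
	"fmt"
	"math"
)

// quotaPerUnit is an assumed stand-in for common.QuotaPerUnit.
const quotaPerUnit = 500000.0

// computeQuota mirrors the two branches in the hunk above.
func computeQuota(promptTokens, completionTokens int, modelRatio, completionRatio, modelPrice float64, usePrice bool) int {
	if usePrice {
		// Fixed-price models are billed per call.
		return int(modelPrice * quotaPerUnit)
	}
	// Completion tokens are weighted first, then the total is scaled by the model ratio.
	quota := promptTokens + int(math.Round(float64(completionTokens)*completionRatio))
	quota = int(math.Round(float64(quota) * modelRatio))
	if modelRatio != 0 && quota <= 0 {
		quota = 1 // a billable model never costs less than 1
	}
	return quota
}

func main() {
	fmt.Println(computeQuota(100, 50, 1.5, 3, 0, false)) // (100 + 150) * 1.5
	fmt.Println(computeQuota(0, 0, 1.5, 3, 0, false))    // floor of 1 applies
	fmt.Println(computeQuota(0, 0, 0, 0, 0.5, true))     // fixed price branch
}
```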
@@ -157,7 +140,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest {
}
func TestChannel(c *gin.Context) {
channelId, err := strconv.Atoi(c.Param("id"))
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -165,7 +148,7 @@ func TestChannel(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(channelId, true)
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -218,11 +201,11 @@ func testAllChannels(notify bool) error {
if disableThreshold == 0 {
disableThreshold = 10000000 // an impossible value
}
gopool.Go(func() {
go func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
err, openaiWithStatusErr := testChannel(channel, "")
err, openaiErr := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
@@ -231,29 +214,20 @@ func testAllChannels(notify bool) error {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
ban = true
}
// request error disables the channel
if openaiWithStatusErr != nil {
oaiErr := openaiWithStatusErr.Error
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
if openaiErr != nil {
err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
ban = true
}
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
// disable channel
if ban && isChannelEnabled {
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
service.DisableChannel(channel.Id, channel.Name, err.Error())
}
// enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) {
service.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
@@ -266,7 +240,7 @@ func testAllChannels(notify bool) error {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
})
}()
return nil
}

View File

@@ -146,26 +146,28 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
if task.Quota != 0 {
shouldReturnQuota = true
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
}

View File

@@ -131,7 +131,7 @@ func init() {
}
meta := &relaycommon.RelayInfo{ChannelType: i}
adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta)
adaptor.Init(meta, dto.GeneralOpenAIRequest{})
channelId2Models[i] = adaptor.GetModelList()
}
}

View File

@@ -22,15 +22,13 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
var err *dto.OpenAIErrorWithStatusCode
switch relayMode {
case relayconstant.RelayModeImagesGenerations:
err = relay.ImageHelper(c, relayMode)
err = relay.RelayImageHelper(c, relayMode)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c)
case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode)
err = relay.AudioHelper(c, relayMode)
default:
err = relay.TextHelper(c)
}
@@ -42,13 +40,12 @@ func Relay(c *gin.Context) {
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, channelType, openaiErr)
go processChannelError(c, channelId, openaiErr)
} else {
retryTimes = 0
}
@@ -69,7 +66,7 @@ func Relay(c *gin.Context) {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channelId, channel.Type, openaiErr)
go processChannelError(c, channelId, openaiErr)
}
}
useChannel := c.GetStringSlice("use_channel")
@@ -128,10 +125,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
func processChannelError(c *gin.Context, channelId int, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
channelName := c.GetString("channel_name")
service.DisableChannel(channelId, channelName, err.Error.Message)
}

View File

@@ -1,6 +1,7 @@
package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -13,7 +14,7 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay"
"one-api/service"
"sort"
"strconv"
"time"
@@ -107,13 +108,26 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
}
return err
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
if adaptor == nil {
return errors.New("adaptor not found")
}
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
requestUrl := fmt.Sprintf("%s/fetch", *channel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.SysError(fmt.Sprintf("Get Task error: %v", err))
return err
}
defer req.Body.Close()
// Set the timeout
timeout := time.Second * 15
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Create a new request that carries the context with the timeout
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+channel.Key)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
return err
@@ -131,7 +145,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, req: %s, body: %s", err, string(body), string(responseBody)))
return err
}
if !responseItems.IsSuccess() {
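The hunk above builds the `/fetch` request by hand and attaches a 15-second timeout afterwards with `req.WithContext`. As a minimal sketch, the same POST can be built in one step with `http.NewRequestWithContext`. The helper names `fetchTasks` and `echoFetch`, the stub transport, and the `upstream.invalid` host are assumptions made for illustration and are not part of this commit.

```go
package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)

// roundTripFunc lets a plain function act as an http.RoundTripper,
// so the demo below runs without any network access.
type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }

// fetchTasks is a hypothetical helper: it POSTs {"ids": [...]} to
// baseURL + "/fetch" with a bearer key and a 15-second deadline,
// then returns the raw response body.
func fetchTasks(ctx context.Context, client *http.Client, baseURL, key string, taskIDs []string) ([]byte, error) {
	body, err := json.Marshal(map[string]any{"ids": taskIDs})
	if err != nil {
		return nil, fmt.Errorf("marshal request body: %w", err)
	}

	// Derive the deadline from the caller's context instead of context.Background().
	ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/fetch", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+key)

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do request: %w", err)
	}
	defer resp.Body.Close()
	return io.ReadAll(resp.Body)
}

// echoFetch calls fetchTasks through a stub transport that echoes the
// request body back, which makes the sent payload easy to inspect.
func echoFetch(taskIDs []string) (string, error) {
	client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
		payload, err := io.ReadAll(r.Body)
		if err != nil {
			return nil, err
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Header:     make(http.Header),
			Body:       io.NopCloser(bytes.NewReader(payload)),
			Request:    r,
		}, nil
	})}
	out, err := fetchTasks(context.Background(), client, "http://upstream.invalid", "test-key", taskIDs)
	return string(out), err
}

func main() {
	out, err := echoFetch([]string{"a", "b"})
	fmt.Println(out, err)
}
```

Deriving the timeout from the incoming context means a cancelled caller also cancels the upstream call. Whether that is wanted for this background task is a design choice the commit does not state.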


@@ -5,10 +5,11 @@ import (
"github.com/Calcium-Ion/go-epay/epay"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"one-api/constant"
"log"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/service"
"strconv"


@@ -1,34 +1,13 @@
package dto
type AudioRequest struct {
Model string `json:"model"`
Input string `json:"input"`
Voice string `json:"voice"`
Speed float64 `json:"speed,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type AudioResponse struct {
Text string `json:"text"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}


@@ -12,11 +12,9 @@ type ImageRequest struct {
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
}
}


@@ -24,3 +24,14 @@ type OpenAIModels struct {
Root string `json:"root"`
Parent *string `json:"parent"`
}
type ModelPricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}


@@ -1,19 +0,0 @@
package dto
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
}
type RerankResponseDocument struct {
Document any `json:"document"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
}


@@ -11,7 +11,6 @@ type GeneralOpenAIRequest struct {
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
@@ -26,12 +25,11 @@ type GeneralOpenAIRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
}
type OpenAITools struct {
@@ -45,12 +43,8 @@ type OpenAIFunction struct {
Parameters any `json:"parameters,omitempty"`
}
type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
func (r GeneralOpenAIRequest) GetMaxTokens() int {
return int(r.MaxTokens)
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
return int64(r.MaxTokens)
}
func (r GeneralOpenAIRequest) ParseInput() []string {
@@ -104,11 +98,6 @@ func (m Message) StringContent() string {
return string(m.Content)
}
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
@@ -148,7 +137,7 @@ func (m Message) ParseContent() []MediaMessage {
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "high"
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
@@ -157,16 +146,7 @@ func (m Message) ParseContent() []MediaMessage {
Detail: subObj["detail"].(string),
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: "high",
},
})
}
}
}
return contentList


@@ -66,6 +66,10 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
return c.Content == nil && len(c.ToolCalls) == 0
}
func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
c.Content = &s
}
@@ -86,11 +90,9 @@ type ToolCall struct {
}
type FunctionCall struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format
Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
type ChatCompletionsStreamResponse struct {
@@ -100,23 +102,10 @@ type ChatCompletionsStreamResponse struct {
Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
if c.SystemFingerprint == nil {
return ""
}
return *c.SystemFingerprint
}
func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
c.SystemFingerprint = &s
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
type CompletionsStreamResponse struct {

go.mod

@@ -38,7 +38,6 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect

go.sum

@@ -18,8 +18,6 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
@@ -200,7 +198,6 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -209,7 +206,6 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

main.go

@@ -3,14 +3,12 @@ package main
import (
"embed"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/controller"
"one-api/middleware"
"one-api/model"
@@ -22,10 +20,10 @@ import (
_ "net/http/pprof"
)
//go:embed web/dist
// /go:embed web/dist
var buildFS embed.FS
//go:embed web/dist/index.html
// /go:embed web/dist/index.html
var indexPage []byte
func main() {
@@ -91,14 +89,12 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
if common.IsMasterNode && constant.UpdateTask {
gopool.Go(func() {
controller.UpdateMidjourneyTaskBulk()
})
gopool.Go(func() {
controller.UpdateTaskBulk()
})
}
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
common.SafeGoroutine(func() {
controller.UpdateTaskBulk()
})
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
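The hunk above replaces `gopool.Go` with `common.SafeGoroutine` for the two background updaters. The body of `common.SafeGoroutine` is not shown in this diff, so the following is only a sketch of the usual pattern such a helper follows: start a goroutine and recover from a panic so that one failing task does not crash the process. The names `safeGo` and `runDemo`, and the `onPanic` callback, are hypothetical.

```go
package main

import (
	"fmt"
	"sync"
)

// safeGo is a hypothetical stand-in for a SafeGoroutine-style helper.
// It runs f in a new goroutine and hands any recovered panic value to
// onPanic instead of letting it terminate the program.
func safeGo(f func(), onPanic func(recovered any)) {
	go func() {
		defer func() {
			if r := recover(); r != nil {
				onPanic(r)
			}
		}()
		f()
	}()
}

// runDemo starts a task that panics and returns the recovered message.
func runDemo() string {
	var wg sync.WaitGroup
	var msg string
	wg.Add(1)
	safeGo(
		func() { panic("task failed") },
		func(r any) {
			msg = fmt.Sprint(r)
			wg.Done() // signal that the panic was observed
		},
	)
	wg.Wait()
	return msg
}

func main() {
	fmt.Println("recovered:", runDemo())
}
```

Note that the same hunk also drops the `common.IsMasterNode && constant.UpdateTask` guard, so after this change both updaters start on every node regardless of how they are wrapped.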


@@ -1,7 +1,6 @@
package middleware
import (
"errors"
"fmt"
"net/http"
"one-api/common"
@@ -26,10 +25,6 @@ func Distribute() func(c *gin.Context) {
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if ok {
@@ -146,7 +141,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error())
return nil, false, err
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
@@ -159,22 +154,18 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
if modelRequest.Model == "" {
modelRequest.Model = "dall-e"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
relayMode = relayconstant.RelayModeAudioTranscription
if modelRequest.Model == "" {
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
}
c.Set("relay_mode", relayMode)
}
return &modelRequest, shouldSelectChannel, nil
}
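After this change, `getModelRequest` fills in a default model only when the request did not name one, and the choice depends on the URL prefix. The rule shown in the hunk can be sketched as a pure function. The function name `defaultModelForPath` is hypothetical, and only the image and audio defaults visible above are covered.

```go
package main

import (
	"fmt"
	"strings"
)

// defaultModelForPath is a hypothetical helper mirroring the fallback
// rules in the hunk above: an explicit model always wins; otherwise
// the default is picked from the request path.
func defaultModelForPath(path, model string) string {
	if model != "" {
		return model
	}
	switch {
	case strings.HasPrefix(path, "/v1/images/generations"):
		return "dall-e"
	case strings.HasPrefix(path, "/v1/audio/speech"):
		return "tts-1"
	case strings.HasPrefix(path, "/v1/audio"):
		// Transcriptions and translations both fall back to whisper-1.
		return "whisper-1"
	}
	return model
}

func main() {
	fmt.Println(defaultModelForPath("/v1/audio/speech", ""))
	fmt.Println(defaultModelForPath("/v1/audio/transcriptions", ""))
	fmt.Println(defaultModelForPath("/v1/images/generations", "dall-e-3"))
}
```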
@@ -187,7 +178,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
@@ -207,11 +197,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
//case common.ChannelTypeAIProxyLibrary:
// c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
c.Set("api_version", channel.Other)
}
}


@@ -56,11 +56,6 @@ func getPriority(group string, model string, retry int) (int, error) {
return 0, err
}
if len(priorities) == 0 {
// If no priority was found, return an error
return 0, errors.New("数据库一致性被破坏")
}
// Determine which priority to use
var priorityToUse int
if retry >= len(priorities) {
@@ -204,7 +199,7 @@ func FixAbility() (int, error) {
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err


@@ -1,10 +1,8 @@
package model
import (
"encoding/json"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Channel struct {
@@ -31,38 +29,6 @@ type Channel struct {
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
}
func (channel *Channel) GetModels() []string {
if channel.Models == "" {
return []string{}
}
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
common.SysError("failed to unmarshal other info: " + err.Error())
}
}
return otherInfo
}
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil {
common.SysError("failed to marshal other info: " + err.Error())
return
}
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) Save() error {
return DB.Save(channel).Error
}
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
@@ -247,31 +213,15 @@ func (channel *Channel) Delete() error {
return err
}
func UpdateChannelStatusById(id int, status int, reason string) {
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {
common.SysError("failed to update ability status: " + err.Error())
}
channel, err := GetChannelById(id, true)
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
// find channel by id error, directly update status
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
} else {
// find channel by id success, update status and other info
info := channel.GetOtherInfo()
info["status_reason"] = reason
info["status_time"] = common.GetTimestamp()
channel.SetOtherInfo(info)
channel.Status = status
err = channel.Save()
if err != nil {
common.SysError("failed to update channel status: " + err.Error())
}
common.SysError("failed to update channel status: " + err.Error())
}
}
func UpdateChannelUsedQuota(id int, quota int) {


@@ -3,7 +3,6 @@ package model
import (
"context"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"strings"
@@ -88,7 +87,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
common.SafeGoroutine(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
})
}


@@ -86,9 +86,9 @@ func InitDB() (err error) {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode {
return nil
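The pool settings above read integer values from the environment and fall back to a default (100 idle connections, 1000 open connections, 60 seconds lifetime). The bodies of `common.GetOrDefault` and `common.GetEnvOrDefault` are not part of this diff, so the following is only a sketch of how such a lookup is commonly written. The names `parseIntOrDefault` and `envIntOrDefault` are hypothetical.

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// parseIntOrDefault returns def when raw is empty or not a valid integer.
func parseIntOrDefault(raw string, def int) int {
	if raw == "" {
		return def
	}
	v, err := strconv.Atoi(raw)
	if err != nil {
		return def
	}
	return v
}

// envIntOrDefault is a hypothetical counterpart to the helper used in
// the diff: it reads an environment variable and falls back to def.
func envIntOrDefault(name string, def int) int {
	return parseIntOrDefault(os.Getenv(name), def)
}

func main() {
	fmt.Println(envIntOrDefault("SQL_MAX_IDLE_CONNS", 100))
	fmt.Println(envIntOrDefault("SQL_MAX_OPEN_CONNS", 1000))
	fmt.Println(envIntOrDefault("SQL_MAX_LIFETIME", 60))
}
```

Silently falling back on a malformed value keeps startup robust but can hide a typo in the deployment configuration. Logging the parse error is a reasonable alternative.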


@@ -99,7 +99,6 @@ func InitOptionMap() {
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -211,8 +210,6 @@ func updateOptionMap(key string, value string) (err error) {
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":


@@ -2,28 +2,18 @@ package model
import (
"one-api/common"
"one-api/dto"
"sync"
"time"
)
type Pricing struct {
Available bool `json:"available"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_group,omitempty"`
}
var (
pricingMap []Pricing
pricingMap []dto.ModelPricing
lastGetPricingTime time.Time
updatePricingLock sync.Mutex
)
func GetPricing(group string) []Pricing {
func GetPricing(group string) []dto.ModelPricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
@@ -31,7 +21,7 @@ func GetPricing(group string) []Pricing {
updatePricing()
}
if group != "" {
userPricingMap := make([]Pricing, 0)
userPricingMap := make([]dto.ModelPricing, 0)
models := GetGroupModels(group)
for _, pricing := range pricingMap {
if !common.StringsContains(models, pricing.ModelName) {
@@ -52,9 +42,9 @@ func updatePricing() {
allModels[model] = i
}
pricingMap = make([]Pricing, 0)
pricingMap = make([]dto.ModelPricing, 0)
for model, _ := range allModels {
pricing := Pricing{
pricing := dto.ModelPricing{
Available: true,
ModelName: model,
}


@@ -78,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil
}


@@ -250,9 +250,11 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
}
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return 0, err
if !token.UnlimitedQuota {
err = DecreaseTokenQuota(tokenId, quota)
if err != nil {
return 0, err
}
}
err = DecreaseUserQuota(token.UserId, quota)
return userQuota - quota, err
@@ -270,13 +272,15 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
return err
}
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
}
}
if sendEmail {
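Both hunks above contrast two ways of handling tokens marked `UnlimitedQuota`: one side always adjusts the token's own quota, the other skips that step for unlimited tokens while still charging the user. A simplified in-memory sketch of the guarded variant follows. The types `token` and `user` and the function `preConsume` are hypothetical; the real code works against the database and performs further checks that are not visible in this diff.

```go
package main

import (
	"errors"
	"fmt"
)

// token and user are hypothetical in-memory stand-ins for the
// database-backed models referenced in the diff.
type token struct {
	RemainQuota    int
	UnlimitedQuota bool
}

type user struct {
	Quota int
}

// preConsume charges quota to the user and, only for limited tokens,
// to the token as well. It returns the user's remaining quota.
func preConsume(t *token, u *user, quota int) (int, error) {
	if u.Quota < quota {
		return 0, fmt.Errorf("insufficient user quota, remaining %d", u.Quota)
	}
	if !t.UnlimitedQuota {
		if t.RemainQuota < quota {
			return 0, errors.New("insufficient token quota")
		}
		t.RemainQuota -= quota
	}
	u.Quota -= quota
	return u.Quota, nil
}

func main() {
	unlimited := &token{UnlimitedQuota: true}
	u := &user{Quota: 100}
	left, err := preConsume(unlimited, u, 30)
	fmt.Println(left, unlimited.RemainQuota, err)
}
```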


@@ -298,8 +298,7 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" {
return errors.New("用户名或密码为空")
}
// find by username or email
DB.Where("username = ? OR email = ?", user.Username, user.Username).First(user)
DB.Where(User{Username: user.Username}).First(user)
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")


@@ -2,7 +2,6 @@ package model
import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
@@ -29,12 +28,12 @@ func init() {
}
func InitBatchUpdater() {
gopool.Go(func() {
go func() {
for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
batchUpdate()
}
})
}()
}
func addNewRecord(type_ int, id int, value int) {


@@ -10,13 +10,10 @@ import (
type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo)
Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
@@ -39,5 +36,4 @@ type TaskAdaptor interface {
GetChannelName() string
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
}


@@ -8,7 +8,6 @@ import (
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
@@ -16,18 +15,14 @@ import (
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl)
if info.RelayMode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
}
return fullRequestURL, nil
}
@@ -44,49 +39,33 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
return baiduEmbeddingRequest, nil
default:
aliReq := requestOpenAI2Ali(*request)
return aliReq, nil
baiduRequest := requestOpenAI2Ali(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
aliRequest := oaiImage2Ali(request)
return aliRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
if info.IsStream {
err, usage = aliStreamHandler(c, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
default:
err, usage = aliHandler(c, resp)
}
}
return


@@ -60,40 +60,13 @@ type AliUsage struct {
TotalTokens int `json:"total_tokens"`
}
type TaskResult struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
type AliOutput struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliResponse struct {
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}


@@ -1,177 +0,0 @@
package ali
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
var imageRequest AliImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
var aliResponse AliResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
waitSeconds := 3
step := 0
maxStep := 20
var taskResponse AliResponse
var responseBody []byte
for {
step++
rsp, err, body := updateTask(info, taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
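_, err = c.Writer.Write(jsonResponse)
if err != nil {
// the status line has already been sent, so the failed write can only be logged
common.LogError(c, "write_response_body_failed: "+err.Error())
}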
return nil, nil
}
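
The polling loop in asyncTaskWait above can be sketched in isolation as follows. This is a minimal illustration under stated assumptions, not the project's code: pollUntilDone, isTerminal, and fetchStatus are hypothetical names, and the zero sleep interval in main is chosen only so the demo finishes instantly. The terminal status strings are the ones that appear in the diff.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// isTerminal reports whether a task status means polling should stop.
// The status strings mirror the cases handled in asyncTaskWait.
func isTerminal(status string) bool {
	switch status {
	case "FAILED", "CANCELED", "SUCCEEDED", "UNKNOWN":
		return true
	}
	return false
}

// pollUntilDone (hypothetical helper) calls fetchStatus at most maxSteps times,
// sleeping for interval between attempts, and returns the first terminal status.
func pollUntilDone(fetchStatus func() (string, error), maxSteps int, interval time.Duration) (string, error) {
	for step := 1; step <= maxSteps; step++ {
		status, err := fetchStatus()
		if err != nil {
			return "", err
		}
		if isTerminal(status) {
			return status, nil
		}
		if step < maxSteps {
			time.Sleep(interval)
		}
	}
	return "", errors.New("task polling timed out")
}

func main() {
	// Simulated upstream: two non-terminal answers, then success.
	statuses := []string{"PENDING", "RUNNING", "SUCCEEDED"}
	i := 0
	fetch := func() (string, error) {
		s := statuses[i]
		i++
		return s, nil
	}
	status, err := pollUntilDone(fetch, 20, 0)
	fmt.Println(status, err)
}
```

Passing the fetch step in as a function keeps the retry budget and the terminal-state check testable without any HTTP traffic.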

View File

@@ -16,13 +16,34 @@ import (
const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
if request.TopP >= 1 {
request.TopP = 0.999
} else if request.TopP <= 0 {
request.TopP = 0.001
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
//prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, AliMessage{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &AliChatRequest{
// use the model name with the "-internet" suffix already trimmed
Model: aliModel,
Input: AliInput{
//Prompt: prompt,
Messages: messages,
},
Parameters: AliParameters{
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
EnableSearch: enableSearch,
},
}
return &request
}
func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
@@ -89,7 +110,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
content, _ := json.Marshal(response.Output.Text)
choice := dto.OpenAITextResponseChoice{
Index: 0,
@@ -113,7 +134,7 @@ func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" {
@@ -133,7 +154,18 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
@@ -155,7 +187,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliResponse
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
@@ -189,7 +221,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
}
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliResponse
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
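
The custom split function passed to scanner.Split in aliStreamHandler can be sketched on its own as below. This is an illustrative variant, not the code from the diff: splitOnNewline and readDataLines are hypothetical names, and it uses bytes.IndexByte where the diff converts the buffer to a string and calls strings.Index. The sample stream in main is made up.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"strings"
)

// splitOnNewline is a bufio.SplitFunc that yields one token per '\n'-terminated
// line and returns any unterminated trailing data as a final token at EOF.
func splitOnNewline(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	if i := bytes.IndexByte(data, '\n'); i >= 0 {
		return i + 1, data[:i], nil
	}
	if atEOF {
		return len(data), data, nil
	}
	// No full line yet: ask the scanner for more data.
	return 0, nil, nil
}

// readDataLines (hypothetical helper) returns the payload of every line that
// starts with "data:", with a trailing '\r' and surrounding spaces removed.
func readDataLines(stream string) []string {
	scanner := bufio.NewScanner(strings.NewReader(stream))
	scanner.Split(splitOnNewline)
	var payloads []string
	for scanner.Scan() {
		line := strings.TrimSuffix(scanner.Text(), "\r")
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		payloads = append(payloads, strings.TrimSpace(strings.TrimPrefix(line, "data:")))
	}
	return payloads
}

func main() {
	stream := "event: result\ndata: {\"a\":1}\r\n\ndata:{\"b\":2}"
	for _, p := range readDataLines(stream) {
		fmt.Println(p)
	}
}
```

Searching the byte slice directly avoids copying the buffer into a string on every call of the split function.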

View File

@@ -7,19 +7,14 @@ import (
"io"
"net/http"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
)
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data
} else {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
@@ -43,29 +38,6 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
return resp, nil
}
func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
err = a.SetupRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := service.GetHttpClient().Do(req)
if err != nil {

View File

@@ -20,17 +20,7 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
@@ -46,7 +36,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -63,17 +53,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return claudeReq, err
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
err, usage = awsStreamHandler(c, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}

View File

@@ -7,7 +7,6 @@ var awsModelIDMap = map[string]string{
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
}
var ChannelName = "aws"

View File

@@ -13,9 +13,7 @@ import (
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
@@ -113,7 +111,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return nil, &usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -158,20 +156,16 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var usage relaymodel.Usage
var id string
var model string
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
@@ -214,17 +208,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return false
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}

View File

@@ -2,7 +2,6 @@ package baidu
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -16,79 +15,44 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
suffix = "embeddings/"
}
var fullRequestURL string
switch info.UpstreamModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "ERNIE-4.0-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-3.5-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Character-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k"
case "ERNIE-Functions-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k"
case "ERNIE-Lite-8K-0922":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "Yi-34B-Chat":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
default:
suffix += strings.ToLower(info.UpstreamModelName)
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
@@ -104,11 +68,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
return baiduEmbeddingRequest, nil
@@ -118,10 +82,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
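
The model-to-endpoint switch in GetRequestURL above could also be expressed as a lookup table. The sketch below is one possible shape, not the project's implementation: endpointByModel, embeddingPrefixes, endpointSuffix, and buildURL are hypothetical names, and the table holds only a handful of the mappings that appear in the diff.

```go
package main

import (
	"fmt"
	"strings"
)

// endpointByModel maps a model name to its endpoint path segment.
// Entries are a subset of the cases shown in the diff.
var endpointByModel = map[string]string{
	"ERNIE-4.0-8K":   "completions_pro",
	"ERNIE-3.5-8K":   "completions",
	"ERNIE-Speed-8K": "ernie_speed",
	"Embedding-V1":   "embedding-v1",
	"bge-large-zh":   "bge_large_zh",
	"tao-8k":         "tao_8k",
}

// embeddingPrefixes lists model-name prefixes served under "embeddings/".
var embeddingPrefixes = []string{"Embedding", "bge-large", "tao-8k"}

// endpointSuffix returns "chat/<name>" or "embeddings/<name>" for a model,
// falling back to the lower-cased model name when it is not in the table.
func endpointSuffix(model string) string {
	group := "chat/"
	for _, prefix := range embeddingPrefixes {
		if strings.HasPrefix(model, prefix) {
			group = "embeddings/"
			break
		}
	}
	if name, ok := endpointByModel[model]; ok {
		return group + name
	}
	return group + strings.ToLower(model)
}

// buildURL joins a base URL with the wenxinworkshop path and the suffix.
func buildURL(baseURL, model string) string {
	return fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", baseURL, endpointSuffix(model))
}

func main() {
	fmt.Println(buildURL("https://aip.baidubce.com", "ERNIE-Speed-8K"))
}
```

With a table, adding a model becomes a one-line data change, and the fallback behaviour stays in a single place.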

View File

@@ -1,22 +1,20 @@
package baidu
var ModelList = []string{
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-4.0-8K",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Lite-8K",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"ERNIE-Character-8K",
"ERNIE-Functions-8K",
//"ERNIE-Bot-4",
//"ERNIE-Bot-8K",
//"ERNIE-Bot",
//"ERNIE-Speed",
//"ERNIE-Bot-turbo",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}
var ChannelName = "baidu"

View File

@@ -11,16 +11,9 @@ type BaiduMessage struct {
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {

View File

@@ -22,33 +22,17 @@ import (
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
}
return &baiduRequest
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {

View File

@@ -21,17 +21,7 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
a.RequestMode = RequestModeMessage
} else {
@@ -58,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -69,17 +59,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -8,7 +8,6 @@ var ModelList = []string{
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-sonnet-20240620",
}
var ChannelName = "claude"

View File

@@ -5,18 +5,11 @@ type ClaudeMetadata struct {
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
}
type ClaudeMessageSource struct {
@@ -30,18 +23,6 @@ type ClaudeMessage struct {
Content any `json:"content"`
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
@@ -54,9 +35,7 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
@@ -65,20 +44,24 @@ type ClaudeError struct {
}
type ClaudeResponse struct {
Id string `json:"id"`
Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
ContentBlock *ClaudeMediaMessage `json:"content_block"`
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
Id string `json:"id"`
Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
}
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`

View File

@@ -9,7 +9,6 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
@@ -28,7 +27,6 @@ func stopReasonClaude2OpenAI(reason string) string {
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
@@ -59,22 +57,6 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
}
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
// checked assertion: a missing or non-string "type" must not panic
schemaType, _ := params["type"].(string)
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: schemaType,
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
@@ -83,24 +65,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
Tools: claudeTools,
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096
}
if textRequest.Stop != nil {
// stop may be a string or an array of strings; normalize it to []string
switch stop := textRequest.Stop.(type) {
case string:
claudeRequest.StopSequences = []string{stop}
case []interface{}:
stopSequences := make([]string, 0, len(stop))
for _, item := range stop {
// skip non-string entries instead of panicking on a failed assertion
if s, ok := item.(string); ok {
stopSequences = append(stopSequences, s)
}
}
claudeRequest.StopSequences = stopSequences
}
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range textRequest.Messages {
@@ -200,7 +168,6 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
@@ -216,33 +183,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionCall{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else {
return nil, nil
}
return nil, nil
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
}
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
@@ -258,10 +202,6 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
@@ -274,11 +214,6 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
Object: "chat.completion",
Created: common.GetTimestamp(),
}
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
@@ -293,115 +228,119 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
choices = append(choices, choice)
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
if message.Type == "tool_use" {
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCall{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionCall{
Name: message.Name,
Arguments: string(args),
},
})
for i, message := range claudeResponse.Content {
content, _ := json.Marshal(message.Text)
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
}
}
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choice.SetStringContent(responseText)
if len(tools) > 0 {
choice.Message.ToolCalls = tools
}
choices = append(choices, choice)
fullTextResponse.Choices = choices
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
if len(data) < 6 || !strings.HasPrefix(data, "data:") {
continue
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
data = strings.TrimPrefix(data, "data:")
data = strings.TrimSpace(data)
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
continue
if atEOF {
return len(data), data, nil
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start: read the response id, model, and prompt usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
stopChan <- true
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
return true
}
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start: read usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = info.PromptTokens
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
resp.Body.Close()
return nil, usage
}

View File

@@ -1,105 +0,0 @@
package cloudflare
import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil
default:
return request, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// Read the uploaded "file" form field
file, _, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
// In-memory buffer that holds the uploaded file content
requestBody := &bytes.Buffer{}
// Copy the uploaded file content into the buffer
if _, err := io.Copy(requestBody, file); err != nil {
return nil, err
}
return requestBody, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = cfStreamHandler(c, resp, info)
} else {
err, usage = cfHandler(c, resp, info)
}
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = cfSTTHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
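The `GetRequestURL` switch in this removed adaptor picks one of three Cloudflare endpoints by relay mode, with `info.ApiVersion` used as the account ID. A standalone sketch of that routing is below. The mode constants and the function name `cloudflareURL` are stand-ins for illustration, and the base URL is a placeholder.

```go
package main

import "fmt"

// Hypothetical stand-ins for the project's relay mode constants.
const (
	modeChatCompletions = iota
	modeEmbeddings
	modeOther
)

// cloudflareURL mirrors the switch in GetRequestURL.
// accountID plays the role of info.ApiVersion in the adaptor.
func cloudflareURL(baseURL, accountID, model string, mode int) string {
	switch mode {
	case modeChatCompletions:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", baseURL, accountID)
	case modeEmbeddings:
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", baseURL, accountID)
	default:
		// Every other mode goes to the per-model "run" endpoint.
		return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", baseURL, accountID, model)
	}
}

func main() {
	base := "https://example.invalid" // placeholder base URL
	fmt.Println(cloudflareURL(base, "acct123", "@cf/meta/llama-3-8b-instruct", modeChatCompletions))
	fmt.Println(cloudflareURL(base, "acct123", "@cf/meta/llama-3-8b-instruct", modeOther))
}
```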

View File

@@ -1,38 +0,0 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}
var ChannelName = "cloudflare"

View File

@@ -1,21 +0,0 @@
package cloudflare
import "one-api/dto"
type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
}
type CfAudioResponse struct {
Result CfSTTResult `json:"result"`
}
type CfSTTResult struct {
Text string `json:"text"`
}

View File

@@ -1,156 +0,0 @@
package cloudflare
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
p, _ := textRequest.Prompt.(string)
return &CfRequest{
Prompt: p,
MaxTokens: textRequest.GetMaxTokens(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
id := service.GetResponseID(c)
var responseText string
isFirst := true
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
choice.Delta.Role = "assistant"
responseText += choice.Delta.GetContentString()
}
response.Id = id
response.Model = info.UpstreamModelName
err = service.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}
return nil, usage
}
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = service.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var cfResp CfAudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
audioResp := &dto.AudioResponse{
Text: cfResp.Result.Text,
}
jsonResponse, err := json.Marshal(audioResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
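`cfSTTHandler` unwraps Cloudflare's nested `result.text` field and re-encodes it as a flat transcription response. The sketch below shows only that JSON reshaping, using local copies of the two structs. The flat `text` field on the output type is an assumption about `dto.AudioResponse`, and `convertSTT` is a hypothetical name.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// cfAudioResponse mirrors CfAudioResponse / CfSTTResult from the diff.
type cfAudioResponse struct {
	Result struct {
		Text string `json:"text"`
	} `json:"result"`
}

// audioResponse is an assumed shape for dto.AudioResponse.
type audioResponse struct {
	Text string `json:"text"`
}

// convertSTT decodes the upstream body and re-encodes it in the flat form.
func convertSTT(body []byte) ([]byte, error) {
	var upstream cfAudioResponse
	if err := json.Unmarshal(body, &upstream); err != nil {
		return nil, err
	}
	return json.Marshal(audioResponse{Text: upstream.Result.Text})
}

func main() {
	out, err := convertSTT([]byte(`{"result":{"text":"hello world"}}`))
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```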

View File

@@ -1,7 +1,6 @@
package cohere
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@@ -9,31 +8,16 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -42,7 +26,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}
@@ -50,19 +34,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
} else {
if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info)
} else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
}
return
}

View File

@@ -2,7 +2,6 @@ package cohere
var ModelList = []string{
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}
var ChannelName = "cohere"

View File

@@ -1,13 +1,11 @@
package cohere
import "one-api/dto"
type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
MaxTokens int64 `json:"max_tokens"`
}
type ChatHistory struct {
@@ -30,19 +28,6 @@ type CohereResponseResult struct {
Meta CohereMeta `json:"meta"`
}
type CohereRerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents"`
}
type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`

View File

@@ -9,10 +9,8 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -47,20 +45,6 @@ func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
return &cohereReq
}
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
if rerankRequest.TopN == 0 {
rerankRequest.TopN = 1
}
cohereReq := CohereRerankRequest{
Query: rerankRequest.Query,
Documents: rerankRequest.Documents,
Model: rerankRequest.Model,
TopN: rerankRequest.TopN,
ReturnDocuments: true,
}
return &cohereReq
}
func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
@@ -72,7 +56,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
@@ -100,14 +84,9 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
stopChan <- true
}()
service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
@@ -119,7 +98,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = info.UpstreamModelName
openaiResp.Model = modelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@@ -158,7 +137,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
}
return nil, usage
}
@@ -208,42 +187,3 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = 0
usage.TotalTokens = info.PromptTokens
} else {
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
}
var rerankResp dto.RerankResponse
rerankResp.Results = cohereResp.Results
rerankResp.Usage = usage
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
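The removed `cohereRerankHandler` prefers the token counts Cohere bills for and falls back to the locally estimated prompt tokens when the upstream reports zero input tokens. A small sketch of that rule, with simplified local types and a hypothetical function name `rerankUsage`:

```go
package main

import "fmt"

// usage and billedUnits are simplified stand-ins for dto.Usage and CohereBilledUnits.
type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

type billedUnits struct {
	InputTokens  int
	OutputTokens int
}

// rerankUsage applies the fallback used by the handler: if the upstream
// reports no input tokens, bill the locally estimated prompt tokens only.
func rerankUsage(billed billedUnits, estimatedPromptTokens int) usage {
	if billed.InputTokens == 0 {
		return usage{
			PromptTokens: estimatedPromptTokens,
			TotalTokens:  estimatedPromptTokens,
		}
	}
	return usage{
		PromptTokens:     billed.InputTokens,
		CompletionTokens: billed.OutputTokens,
		TotalTokens:      billed.InputTokens + billed.OutputTokens,
	}
}

func main() {
	fmt.Printf("%+v\n", rerankUsage(billedUnits{}, 42))
	fmt.Printf("%+v\n", rerankUsage(billedUnits{InputTokens: 10, OutputTokens: 3}, 42))
}
```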

View File

@@ -1,70 +0,0 @@
package dify
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Dify(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = difyStreamHandler(c, resp, info)
} else {
err, usage = difyHandler(c, resp, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,5 +0,0 @@
package dify
var ModelList []string
var ChannelName = "dify"

View File

@@ -1,35 +0,0 @@
package dify
import "one-api/dto"
type DifyChatRequest struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
ResponseMode string `json:"response_mode"`
User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"`
}
type DifyMetaData struct {
Usage dto.Usage `json:"usage"`
}
type DifyData struct {
WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"`
}
type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}
type DifyChunkChatCompletionResponse struct {
Event string `json:"event"`
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
Data DifyData `json:"data"`
MetaData DifyMetaData `json:"metadata"`
}

View File

@@ -1,156 +0,0 @@
package dify
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
content := ""
for _, message := range request.Messages {
if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n"
} else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n"
} else {
content += "USER: \n" + message.StringContent() + "\n"
}
}
mode := "blocking"
if request.Stream {
mode = "streaming"
}
user := request.User
if user == "" {
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
}
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
} else if difyResponse.Event == "message" {
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
return &response
}
func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var responseText string
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
break
} else if difyResponse.Event == "error" {
break
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
}
}
err = service.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(responseText, "gpt-3.5-turbo")
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
return nil, usage
}
func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &difyResponse.MetaData.Usage
}
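`requestOpenAI2Dify` collapses an OpenAI-style message list into a single query string, prefixing each message with its role. The sketch below reproduces that flattening with a simplified `message` type; the type and the function name `flattenMessages` are illustrative only.

```go
package main

import (
	"fmt"
	"strings"
)

// message is a simplified stand-in for dto.Message with plain string content.
type message struct {
	Role    string
	Content string
}

// flattenMessages joins all messages into one prompt. Any role other than
// "system" or "assistant" is treated as a user message, as in the diff.
func flattenMessages(messages []message) string {
	var b strings.Builder
	for _, m := range messages {
		switch m.Role {
		case "system":
			b.WriteString("SYSTEM: \n")
		case "assistant":
			b.WriteString("ASSISTANT: \n")
		default:
			b.WriteString("USER: \n")
		}
		b.WriteString(m.Content)
		b.WriteString("\n")
	}
	return b.String()
}

func main() {
	fmt.Print(flattenMessages([]message{
		{Role: "system", Content: "Be brief."},
		{Role: "user", Content: "Hi"},
	}))
}
```

Using `strings.Builder` instead of repeated `+=` avoids reallocating the string on every message; the output is the same.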

View File

@@ -9,47 +9,38 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
// Map from model name to the API version that model requires
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// Look up the version for the model name; if absent, fall back to info.ApiVersion, or to the default "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}
// Look up the version for the model name; if absent, fall back to info.ApiVersion, or to the default "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion
} else {
version = "v1"
}
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -58,24 +49,22 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return CovertGemini2OpenAI(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = geminiChatStreamHandler(c, resp, info)
var responseText string
err, responseText = geminiChatStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
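The Gemini adaptor's `GetRequestURL` resolves the API version in three steps: a per-model override from `modelVersionMap`, then the channel's configured version, then `v1`. A standalone sketch of that lookup is below. It uses the `streamGenerateContent?alt=sse` action, which is one of the two streaming variants shown in the diff. The function name `geminiURL` and the base URL are placeholders.

```go
package main

import "fmt"

// modelVersionMap lists models that must use a specific API version.
var modelVersionMap = map[string]string{
	"gemini-1.5-pro-latest":   "v1beta",
	"gemini-1.5-flash-latest": "v1beta",
	"gemini-ultra":            "v1beta",
}

// geminiURL resolves the version (map override, then apiVersion, then "v1")
// and appends the action for streaming or non-streaming requests.
func geminiURL(baseURL, apiVersion, model string, stream bool) string {
	version, ok := modelVersionMap[model]
	if !ok {
		if apiVersion != "" {
			version = apiVersion
		} else {
			version = "v1"
		}
	}
	action := "generateContent"
	if stream {
		action = "streamGenerateContent?alt=sse"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, model, action)
}

func main() {
	base := "https://example.invalid" // placeholder base URL
	fmt.Println(geminiURL(base, "", "gemini-pro", false))
	fmt.Println(geminiURL(base, "v1", "gemini-1.5-pro-latest", true))
}
```

Note that the map entry wins even when the channel sets an explicit version, so a listed model always goes to `v1beta`.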

View File

@@ -12,15 +12,9 @@ type GeminiInlineData struct {
Data string `json:"data"`
}
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
@@ -65,11 +59,4 @@ type GeminiChatPromptFeedback struct {
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -12,6 +11,8 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
@@ -42,17 +43,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
functions = append(functions, tool.Function)
}
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: functions,
},
}
} else if textRequest.Functions != nil {
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
@@ -132,30 +123,6 @@ func (g *GeminiChatResponse) GetResponseText() string {
return ""
}
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
var toolCalls []dto.ToolCall
item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
//common.SysError("getToolCalls failed: " + err.Error())
return toolCalls
}
toolCall := dto.ToolCall{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -174,11 +141,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
FinishReason: relaycommon.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(&candidate)
} else {
choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
choice.Message.Content = content
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
@@ -187,17 +151,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
respFirst := geminiResponse.Candidates[0].Content.Parts[0]
if respFirst.FunctionCall != nil {
// function response
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
} else {
// text response
choice.Delta.SetContentString(respFirst.Text)
}
}
choice.Delta.SetContentString(geminiResponse.GetResponseText())
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
@@ -206,59 +160,74 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
info.SetFirstResponseTime()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\"")
var geminiResponse GeminiChatResponse
err := json.Unmarshal([]byte(data), &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
continue
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
response.Id = id
response.Created = createAt
responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
}
err = service.ObjectData(c, response)
if err != nil {
common.LogError(c, err.Error())
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(dummy.Content)
response := dto.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
resp.Body.Close()
return nil, usage
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -287,10 +256,11 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
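
One side of this diff reads the Gemini stream as server-sent events: it keeps only lines starting with `data: `, decodes each payload as JSON, and takes token counts from `usageMetadata` when they are present. The sketch below illustrates that idea in isolation. The `chunk` type and the `collect` function are hypothetical names, and the input format is assumed to be one JSON object per `data: ` line.

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// chunk is a reduced, hypothetical view of one streamed Gemini response.
type chunk struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
	} `json:"candidates"`
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
		TotalTokenCount      int `json:"totalTokenCount"`
	} `json:"usageMetadata"`
}

// collect concatenates the first text part of every chunk and keeps the
// most recent non-zero usage counts reported by the upstream.
func collect(stream string) (text string, prompt, completion int) {
	scanner := bufio.NewScanner(strings.NewReader(stream))
	var sb strings.Builder
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank separators and non-data lines
		}
		var c chunk
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &c); err != nil {
			continue // a malformed chunk is skipped, not fatal
		}
		if len(c.Candidates) > 0 && len(c.Candidates[0].Content.Parts) > 0 {
			sb.WriteString(c.Candidates[0].Content.Parts[0].Text)
		}
		if c.UsageMetadata.TotalTokenCount != 0 {
			prompt = c.UsageMetadata.PromptTokenCount
			completion = c.UsageMetadata.CandidatesTokenCount
		}
	}
	return sb.String(), prompt, completion
}

func main() {
	stream := `data: {"candidates":[{"content":{"parts":[{"text":"Hi"}]}}]}` + "\n\n" +
		`data: {"candidates":[{"content":{"parts":[{"text":" there"}]}}],"usageMetadata":{"promptTokenCount":3,"candidatesTokenCount":2,"totalTokenCount":5}}` + "\n"
	text, prompt, completion := collect(stream)
	fmt.Println(text, prompt, completion)
}
```

Decoding whole JSON payloads avoids the string-prefix matching on `"text": "` that the other side of the diff relies on, at the cost of assuming each event fits on one line.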

View File

@@ -1,71 +0,0 @@
package jina
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = jinaRerankHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -1,8 +0,0 @@
package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
}
var ChannelName = "jina"

View File

@@ -1,35 +0,0 @@
package jina
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -10,22 +10,13 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -42,11 +33,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
switch relayMode {
case relayconstant.RelayModeEmbeddings:
return requestOpenAI2Embeddings(*request), nil
default:
@@ -54,17 +45,15 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)

View File

@@ -1,13 +1,10 @@
package openai
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"mime/multipart"
"net/http"
"one-api/common"
"one-api/dto"
@@ -17,16 +14,15 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
"strings"
)
type Adaptor struct {
ChannelType int
ResponseFormat string
ChannelType int
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
}
@@ -71,84 +67,26 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
request.StreamOptions = nil
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
}
return bytes.NewReader(jsonData), nil
} else {
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// Add the file field
file, header, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
part, err := writer.CreateFormFile("file", header.Filename)
if err != nil {
return nil, errors.New("create form file failed")
}
if _, err := io.Copy(part, file); err != nil {
return nil, errors.New("copy file failed")
}
// Close the multipart writer so the closing boundary is written
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return &requestBody, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
return channel.DoFormRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
if info.IsStream {
var responseText string
var toolCount int
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}

View File

@@ -4,48 +4,45 @@ import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"sync"
"time"
)
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
containStreamUsage := false
responseId := ""
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
//checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder
var usage = &dto.Usage{}
var streamItems []string // store stream items
toolCount := 0
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
service.SetEventStreamHeaders(c)
ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
defer ticker.Stop()
stopChan := make(chan bool)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
defer close(stopChan)
gopool.Go(func() {
defer close(dataChan)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
var streamItems []string // store stream items
for scanner.Scan() {
info.SetFirstResponseTime()
ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
@@ -53,47 +50,39 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
common.SafeSendString(dataChan, data)
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
err := service.StringData(c, data)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
}
streamItems = append(streamItems, data)
}
}
common.SafeSendBool(stopChan, true)
})
select {
case <-ticker.C:
// Timeout handling
common.LogError(c, "streaming timeout")
case <-stopChan:
// Finished normally
}
// Count tokens
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// Parsing the whole batch at once failed; parse the items one by one
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
@@ -108,69 +97,55 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
}
} else {
for _, streamResponse := range streamResponses {
responseId = streamResponse.Id
createAt = streamResponse.Created
systemFingerprint = streamResponse.GetSystemFingerprint()
model = streamResponse.Model
if service.ValidUsage(streamResponse.Usage) {
usage = streamResponse.Usage
containStreamUsage = true
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// Parsing the whole batch at once failed; parse the items one by one
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
}
common.SafeSendBool(stopChan, true)
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
}
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
if info.ShouldIncludeUsage && !containStreamUsage {
response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
service.ObjectData(c, response)
}
service.Done(c)
resp.Body.Close()
return nil, usage
wg.Wait()
return nil, responseTextBuilder.String(), toolCount
}
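
Both versions of the stream handler above join the collected stream items into one JSON array, decode it in a single call, and fall back to decoding item by item when that fails. A minimal sketch of that pattern follows. `streamChunk` and `joinContent` are hypothetical names, and the chunk shape is reduced to the `content` delta only.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// streamChunk is a reduced, hypothetical chat completion chunk.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// joinContent decodes all items in one pass when possible and falls back
// to per-item decoding so that one bad item does not discard the rest.
func joinContent(items []string) string {
	var chunks []streamChunk
	batch := "[" + strings.Join(items, ",") + "]"
	if err := json.Unmarshal([]byte(batch), &chunks); err != nil {
		chunks = nil // drop anything partially decoded before retrying
		for _, item := range items {
			var c streamChunk
			if json.Unmarshal([]byte(item), &c) == nil {
				chunks = append(chunks, c)
			}
		}
	}
	var sb strings.Builder
	for _, c := range chunks {
		for _, choice := range c.Choices {
			sb.WriteString(choice.Delta.Content)
		}
	}
	return sb.String()
}

func main() {
	items := []string{
		`{"choices":[{"delta":{"content":"Hel"}}]}`,
		`{broken`,
		`{"choices":[{"delta":{"content":"lo"}}]}`,
	}
	fmt.Println(joinContent(items))
}
```

The single-pass decode is the fast path; the fallback only costs extra work when an upstream emits a malformed item.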
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -207,8 +182,12 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if simpleResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
@@ -222,134 +201,3 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
return nil, &simpleResponse.Usage
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
return nil, usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var audioResp dto.AudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &audioResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
var text string
switch responseFormat {
case "json":
text, err = getTextFromJSON(responseBody)
case "text":
text, err = getTextFromText(responseBody)
case "srt":
text, err = getTextFromSRT(responseBody)
case "verbose_json":
text, err = getTextFromVerboseJSON(responseBody)
case "vtt":
text, err = getTextFromVTT(responseBody)
}
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func getTextFromVTT(body []byte) (string, error) {
return getTextFromSRT(body)
}
func getTextFromVerboseJSON(body []byte) (string, error) {
var whisperResponse dto.WhisperVerboseJSONResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
return whisperResponse.Text, nil
}
func getTextFromSRT(body []byte) (string, error) {
scanner := bufio.NewScanner(strings.NewReader(string(body)))
var builder strings.Builder
var textLine bool
for scanner.Scan() {
line := scanner.Text()
if textLine {
builder.WriteString(line)
textLine = false
continue
} else if strings.Contains(line, "-->") {
textLine = true
continue
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return builder.String(), nil
}
func getTextFromText(body []byte) (string, error) {
return strings.TrimSuffix(string(body), "\n"), nil
}
func getTextFromJSON(body []byte) (string, error) {
var whisperResponse dto.AudioResponse
if err := json.Unmarshal(body, &whisperResponse); err != nil {
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
}
return whisperResponse.Text, nil
}
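
The `getTextFromSRT` helper in this hunk keeps only the single line that follows each `-->` timestamp line and concatenates the results without a separator. As a hedged illustration of the same idea, the sketch below is a variant that also keeps multi-line cues and joins lines with spaces. `textFromSRT` is a hypothetical function, not the project's implementation, and it assumes cues are separated by blank lines.

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// textFromSRT extracts the spoken text from SRT (or WebVTT) subtitles.
// A cue starts after a timestamp line containing "-->" and ends at the
// next blank line; index lines and headers outside cues are ignored.
func textFromSRT(body string) string {
	scanner := bufio.NewScanner(strings.NewReader(body))
	var lines []string
	inCue := false
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		switch {
		case strings.Contains(line, "-->"):
			inCue = true
		case line == "":
			inCue = false
		case inCue:
			lines = append(lines, line)
		}
	}
	return strings.Join(lines, " ")
}

func main() {
	srt := "1\n00:00:00,000 --> 00:00:01,000\nHello\nworld\n\n" +
		"2\n00:00:01,000 --> 00:00:02,000\nAgain\n"
	fmt.Println(textFromSRT(srt))
}
```

Joining with spaces changes the resulting token count slightly compared with the helper in the diff, so this variant would affect usage accounting if it were swapped in.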

View File

@@ -15,17 +15,7 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -38,17 +28,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -10,22 +10,13 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -38,7 +29,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -48,17 +39,15 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return requestOpenAI2Perplexity(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
var responseText string
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -2,7 +2,6 @@ package suno
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
@@ -15,7 +14,6 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
type TaskAdaptor struct {
@@ -56,7 +54,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
baseURL := info.BaseUrl
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/submit/"+info.Action)
return fullRequestURL, nil
}
@@ -126,34 +124,6 @@ func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
byteBody, err := json.Marshal(body)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil {
common.SysError(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()
// Set the timeout
timeout := time.Second * 15
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// Create a new request that uses the context with the timeout
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
switch action {
case constant.SunoActionMusic:

View File

@@ -6,73 +6,49 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strconv"
"strings"
)
type Adaptor struct {
Sign string
AppID int64
Action string
Version string
Timestamp int64
Sign string
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.Action = "ChatCompletions"
a.Version = "2023-09-01"
a.Timestamp = common.GetTimestamp()
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil
return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", a.Action)
req.Header.Set("X-TC-Version", a.Version)
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
req.Header.Set("X-TC-Action", info.UpstreamModelName)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(a, *request)
tencentRequest := requestOpenAI2Tencent(*request)
tencentRequest.AppId = appId
tencentRequest.SecretId = secretId
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
a.Sign = getTencentSign(*tencentRequest, secretKey)
return tencentRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -1,10 +1,9 @@
package tencent
var ModelList = []string{
"hunyuan-lite",
"hunyuan-standard",
"hunyuan-standard-256K",
"hunyuan-pro",
"ChatPro",
"ChatStd",
"hunyuan",
}
var ChannelName = "tencent"

View File

@@ -1,75 +1,61 @@
package tencent
import "one-api/dto"
type TencentMessage struct {
Role string `json:"Role"`
Content string `json:"Content"`
Role string `json:"role"`
Content string `json:"content"`
}
type TencentChatRequest struct {
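// Model name. Valid values include hunyuan-lite, hunyuan-standard, hunyuan-standard-256K, and hunyuan-pro.
// For a description of each model, see the [Product Overview](https://cloud.tencent.com/document/product/1729/104753).
//
// Note:
// Different models are billed differently; call them as needed according to the [Purchase Guide](https://cloud.tencent.com/document/product/1729/97731).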
Model *string `json:"Model"`
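// Chat context.
// Notes:
// 1. At most 40 entries, ordered in the array from oldest to newest.
// 2. Message.Role may be system, user, or assistant.
// The system role is optional; if present, it must come first in the list. user and assistant must alternate (one question, one answer), starting and ending with a user message, and Content must not be empty. Example Role order: [system (optional) user assistant user assistant user ...].
// 3. The total length of Content across Messages must not exceed the model's input length limit (see the [Product Overview](https://cloud.tencent.com/document/product/1729/104753)); if it does, the earliest content is truncated and only the tail is kept.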
Messages []*TencentMessage `json:"Messages"`
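// Streaming switch.
// Notes:
// 1. Defaults to non-streaming (false) when not set.
// 2. In streaming mode, results are returned incrementally over SSE (read the value from Choices[n].Delta; the increments must be concatenated to obtain the full result).
// 3. In non-streaming mode:
// The call works like an ordinary HTTP request.
// The response takes longer; **set this to true if you need lower latency**.
// The final result is returned only once (read the value from Choices[n].Message).
//
// Note:
// When calling through an SDK, streaming and non-streaming calls retrieve the return value in **different ways**; see the comments or examples in the SDK (in the examples/hunyuan/v20230901/ directory of each language's SDK repository).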
Stream *bool `json:"Stream,omitempty"`
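// Notes:
// 1. Affects the diversity of the output text; the larger the value, the more diverse the generated text.
// 2. Range is [0.0, 1.0]; when not set, each model's recommended value is used.
// 3. Not recommended unless necessary; unreasonable values degrade results.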
TopP *float64 `json:"TopP,omitempty"`
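// Notes:
// 1. Higher values make the output more random, while lower values make it more focused and deterministic.
// 2. Range is [0.0, 2.0]; when not set, each model's recommended value is used.
// 3. Not recommended unless necessary; unreasonable values degrade results.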
Temperature *float64 `json:"Temperature,omitempty"`
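AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account
SecretId string `json:"secret_id"` // SecretId from the official website
// Timestamp is the current UNIX timestamp in seconds, recording when the API request was made.
// For example, 1529223702. If it differs too much from the current time, a signature-expired error occurs.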
Timestamp int64 `json:"timestamp"`
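// Expired is the validity period of the signature, a value in UNIX Epoch timestamp format,
// in seconds. Expired must be greater than Timestamp, and Expired-Timestamp must be less than 90 days.
Expired int64 `json:"expired"`
QueryID string `json:"query_id"` // Request ID, used for troubleshooting
// Temperature: higher values make the output more random, while lower values make it more focused and deterministic.
// Default 1.0, range [0.0, 2.0]. Not recommended unless necessary; unreasonable values degrade results.
// It is recommended to set only one of this parameter and top_p, not both.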
Temperature float64 `json:"temperature"`
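// TopP affects the diversity of the output text; the larger the value, the more diverse the generated text.
// Default 1.0, range [0.0, 1.0]. Not recommended unless necessary; unreasonable values degrade results.
// It is recommended to set only one of this parameter and temperature, not both.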
TopP float64 `json:"top_p"`
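// Stream: 0 for synchronous, 1 for streaming (SSE protocol by default).
// Synchronous requests time out after 60s; streaming is recommended for long content.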
Stream int `json:"stream"`
// Messages is the conversation content, at most 40 entries, ordered in the array from oldest to newest.
// The input content supports at most 3000 tokens in total.
Messages []TencentMessage `json:"messages"`
}
type TencentError struct {
Code int `json:"Code"`
Message string `json:"Message"`
Code int `json:"code"`
Message string `json:"message"`
}
type TencentUsage struct {
PromptTokens int `json:"PromptTokens"`
CompletionTokens int `json:"CompletionTokens"`
TotalTokens int `json:"TotalTokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TencentResponseChoices struct {
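FinishReason string `json:"FinishReason,omitempty"` // End-of-stream flag; "stop" marks the final packet
Messages TencentMessage `json:"Message,omitempty"` // Content; returned in synchronous mode, null in streaming mode. Output content supports at most 1024 tokens in total.
Delta TencentMessage `json:"Delta,omitempty"` // Content; returned in streaming mode, null in synchronous mode. Output content supports at most 1024 tokens in total.
FinishReason string `json:"finish_reason,omitempty"` // End-of-stream flag; "stop" marks the final packet
Messages TencentMessage `json:"messages,omitempty"` // Content; returned in synchronous mode, null in streaming mode. Output content supports at most 1024 tokens in total.
Delta TencentMessage `json:"delta,omitempty"` // Content; returned in streaming mode, null in synchronous mode. Output content supports at most 1024 tokens in total.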
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"Choices,omitempty"` // Results
Created int64 `json:"Created,omitempty"` // UNIX timestamp
Id string `json:"Id,omitempty"` // Conversation ID
Usage TencentUsage `json:"Usage,omitempty"` // Token counts
Error TencentError `json:"Error,omitempty"` // Error information. Note: this field may be null, meaning no valid value was available
Note string `json:"Note,omitempty"` // Note
ReqID string `json:"Req_id,omitempty"` // Unique request ID, returned with every request; supply it when reporting issues with the API
}
type TencentChatResponseSB struct {
Response TencentChatResponse `json:"Response,omitempty"`
Choices []TencentResponseChoices `json:"choices,omitempty"` // Results
Created string `json:"created,omitempty"` // UNIX timestamp as a string
Id string `json:"id,omitempty"` // Conversation ID
Usage dto.Usage `json:"usage,omitempty"` // Token counts
Error TencentError `json:"error,omitempty"` // Error information. Note: this field may be null, meaning no valid value was available
Note string `json:"note,omitempty"` // Note
ReqID string `json:"req_id,omitempty"` // Unique request ID, returned with every request; supply it when reporting issues with the API
}

View File

@@ -3,8 +3,8 @@ package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -15,46 +15,53 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"sort"
"strconv"
"strings"
"time"
)
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]*TencentMessage, 0, len(request.Messages))
func requestOpenAI2Tencent(request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, &TencentMessage{
if message.Role == "system" {
messages = append(messages, TencentMessage{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, TencentMessage{
Role: "assistant",
Content: "Okay",
})
continue
}
messages = append(messages, TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
var req = TencentChatRequest{
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
stream := 0
if request.Stream {
stream = 1
}
if request.TopP != 0 {
req.TopP = &request.TopP
return &TencentChatRequest{
Timestamp: common.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
Messages: messages,
}
if request.Temperature != 0 {
req.Temperature = &request.Temperature
}
return &req
}
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: dto.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
Usage: response.Usage,
}
if len(response.Choices) > 0 {
content, _ := json.Marshal(response.Choices[0].Messages.Content)
@@ -91,51 +98,69 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var TencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &TencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
data = strings.TrimPrefix(data, "data:")
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
err = service.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
service.Done(c)
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var tencentSb TencentChatResponseSB
var TencentResponse TencentChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -144,20 +169,20 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &tencentSb)
err = json.Unmarshal(responseBody, &TencentResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if tencentSb.Response.Error.Code != 0 {
if TencentResponse.Error.Code != 0 {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
Message: TencentResponse.Error.Message,
Code: TencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -180,62 +205,29 @@ func parseTencentConfig(config string) (appId int64, secretId string, secretKey
return
}
func sha256hex(s string) string {
b := sha256.Sum256([]byte(s))
return hex.EncodeToString(b[:])
}
func hmacSha256(s, key string) string {
hashed := hmac.New(sha256.New, []byte(key))
hashed.Write([]byte(s))
return string(hashed.Sum(nil))
}
func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
// build canonical request string
host := "hunyuan.tencentcloudapi.com"
httpRequestMethod := "POST"
canonicalURI := "/"
canonicalQueryString := ""
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
"application/json", host, strings.ToLower(adaptor.Action))
signedHeaders := "content-type;host;x-tc-action"
payload, _ := json.Marshal(req)
hashedRequestPayload := sha256hex(string(payload))
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
httpRequestMethod,
canonicalURI,
canonicalQueryString,
canonicalHeaders,
signedHeaders,
hashedRequestPayload)
// build string to sign
algorithm := "TC3-HMAC-SHA256"
requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
t := time.Unix(timestamp, 0).UTC()
// must be the format 2006-01-02, ref to package time for more info
date := t.Format("2006-01-02")
credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
hashedCanonicalRequest := sha256hex(canonicalRequest)
string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
algorithm,
requestTimestamp,
credentialScope,
hashedCanonicalRequest)
// sign string
secretDate := hmacSha256(date, "TC3"+secKey)
secretService := hmacSha256("hunyuan", secretDate)
secretKey := hmacSha256("tc3_request", secretService)
signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
// build authorization
authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
algorithm,
secId,
credentialScope,
signedHeaders,
signature)
return authorization
func getTencentSign(req TencentChatRequest, secretKey string) string {
params := make([]string, 0)
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
params = append(params, "secret_id="+req.SecretId)
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
params = append(params, "query_id="+req.QueryID)
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
params = append(params, "stream="+strconv.Itoa(req.Stream))
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
var messageStr string
for _, msg := range req.Messages {
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
}
messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params))
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url
mac.Write([]byte(signURL))
sign := mac.Sum([]byte(nil))
return base64.StdEncoding.EncodeToString(sign)
}
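
The SHA1 variant of getTencentSign above sorts "key=value" pairs, joins them onto the endpoint string, and signs the result with HMAC-SHA1. A minimal standalone sketch of that pattern follows; the helper name signParams, the endpoint, and the parameter values are hypothetical and are not part of this change.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"sort"
	"strings"
)

// signParams is a hypothetical helper (not from the diff). It sorts
// "key=value" pairs, appends them to the endpoint as a query string,
// and returns the base64-encoded HMAC-SHA1 of that string.
func signParams(endpoint string, params map[string]string, secretKey string) string {
	pairs := make([]string, 0, len(params))
	for k, v := range params {
		pairs = append(pairs, k+"="+v)
	}
	// Map iteration order is random, so sorting makes the plaintext stable.
	sort.Strings(pairs)
	plain := endpoint + "?" + strings.Join(pairs, "&")

	mac := hmac.New(sha1.New, []byte(secretKey))
	mac.Write([]byte(plain))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}

func main() {
	// Placeholder values chosen for illustration only.
	params := map[string]string{"timestamp": "1529223702", "app_id": "1", "stream": "0"}
	sig := signParams("example.invalid/v1/chat", params, "demo-secret")
	fmt.Println(len(sig)) // a 20-byte SHA1 digest encodes to 28 base64 characters
}
```

Because the pairs are sorted before joining, the signature does not depend on the order in which the parameters were collected.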

View File

@@ -16,17 +16,7 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -38,7 +28,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -46,10 +36,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not an HTTP request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -6,7 +6,6 @@ var ModelList = []string{
"SparkDesk-v2.1",
"SparkDesk-v3.1",
"SparkDesk-v3.5",
"SparkDesk-v4.0",
}
var ChannelName = "xunfei"

View File

@@ -252,8 +252,6 @@ func apiVersion2domain(apiVersion string) string {
return "generalv3"
case "v3.5":
return "generalv3.5"
case "v4.0":
return "4.0Ultra"
}
return "general" + apiVersion
}

View File

@@ -14,17 +14,7 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -42,7 +32,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -52,10 +42,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -153,7 +153,18 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
metaChan := make(chan string)
stopChan := make(chan bool)
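
The hunk above swaps between bufio.ScanLines and a hand-written split function that cuts the stream at blank lines. As a rough illustration of how such a bufio.SplitFunc behaves, here is a standalone sketch; splitOnBlankLine and readEvents are hypothetical names, and unlike the diff it uses bytes.Index and omits the extra check for a ":" in the buffer.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"strings"
)

// splitOnBlankLine is a hypothetical bufio.SplitFunc that yields one token
// per block terminated by a blank line ("\n\n"), as SSE events are.
func splitOnBlankLine(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil // nothing left to read
	}
	if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
		return i + 2, data[:i], nil // skip the separator, return the block
	}
	if atEOF {
		return len(data), data, nil // final block without a trailing separator
	}
	return 0, nil, nil // ask the scanner for more data
}

// readEvents collects every block found in the given stream text.
func readEvents(stream string) []string {
	scanner := bufio.NewScanner(strings.NewReader(stream))
	scanner.Split(splitOnBlankLine)
	var events []string
	for scanner.Scan() {
		events = append(events, scanner.Text())
	}
	return events
}

func main() {
	for _, e := range readEvents("data: a\n\ndata: b\n\ndata: c") {
		fmt.Printf("%q\n", e)
	}
}
```

Returning (0, nil, nil) when no separator is found yet is what lets the scanner buffer a partial event until the rest arrives.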

View File

@@ -10,22 +10,13 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/service"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -39,7 +30,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -49,17 +40,17 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
var responseText string
var toolCount int
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -1,7 +1,7 @@
package zhipu_4v
var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools",
"glm-4", "glm-4v", "glm-3-turbo",
}
var ChannelName = "zhipu_4v"

View File

@@ -9,27 +9,23 @@ import (
)
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
setFirstResponse bool
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
TokenUnlimited bool
StartTime time.Time
ApiType int
IsStream bool
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiVersion string
PromptTokens int
ApiKey string
Organization string
BaseUrl string
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -41,26 +37,24 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
// firstResponseTime = time.Now() - 1 second
apiType, _ := constant.ChannelType2APIType(channelType)
info := &RelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
ApiType: apiType,
ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
@@ -68,11 +62,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
info.SupportStreamOptions = true
}
return info
}
@@ -84,13 +73,6 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
info.IsStream = isStream
}
func (info *RelayInfo) SetFirstResponseTime() {
if !info.setFirstResponse {
info.FirstResponseTime = time.Now()
info.setFirstResponse = true
}
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int

View File

@@ -1,17 +1,49 @@
package common
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"strconv"
"strings"
)
var StopFinishReason = "stop"
func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: dto.OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
var textResponse dto.TextResponseWithError
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return
}
OpenAIErrorWithStatusCode.Error = textResponse.Error
return
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

View File

@@ -20,9 +20,6 @@ const (
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -60,12 +57,6 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
case common.ChannelTypeDify:
apiType = APITypeDify
case common.ChannelTypeJina:
apiType = APITypeJina
case common.ChannelCloudflare:
apiType = APITypeCloudflare
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -13,7 +13,6 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
@@ -23,20 +22,16 @@ const (
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskImageSeed
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
RelayModeMidjourneyAction
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
RelayModeAudioTranslation // whisper
RelayModeSunoFetch
RelayModeSunoFetchByID
RelayModeSunoSubmit
RelayModeRerank
)
func Path2RelayMode(path string) int {
@@ -61,8 +56,6 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
} else if strings.HasPrefix(path, "/v1/rerank") {
relayMode = RelayModeRerank
}
return relayMode
}

View File

@@ -1,10 +1,13 @@
package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,73 +16,71 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"time"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
audioRequest := &dto.AudioRequest{}
err := common.UnmarshalBodyReusable(c, audioRequest)
if err != nil {
return nil, err
}
switch info.RelayMode {
case relayconstant.RelayModeAudioSpeech:
if audioRequest.Model == "" {
return nil, errors.New("model is required")
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
var audioRequest dto.TextToSpeechRequest
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if constant.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
return nil, err
}
}
default:
if audioRequest.Model == "" {
audioRequest.Model = c.PostForm("model")
}
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
} else {
audioRequest = dto.TextToSpeechRequest{
Model: "whisper-1",
}
}
return audioRequest, nil
}
//err := common.UnmarshalBodyReusable(c, &audioRequest)
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
// request validation
if audioRequest.Model == "" {
return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" {
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
}
var err error
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if constant.ShouldCheckPromptSensitive() {
err = service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
@@ -87,12 +88,28 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota)
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
return service.OpenAIErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
succeed := false
defer func() {
if succeed {
return
}
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func() {
go func() {
// negative means add quota back for token & user
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
}()
}()
}
}()
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" {
@@ -105,44 +122,133 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
audioRequest.Model = modelMap[audioRequest.Model]
}
}
relayInfo.UpstreamModelName = audioRequest.Model
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
adaptor.Init(relayInfo)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := relaycommon.GetAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
}
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
req.Header.Set("api-key", apiKey)
req.ContentLength = c.Request.ContentLength
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
if resp != nil {
if resp.StatusCode != http.StatusOK {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
openaiErr := service.RelayErrorHandler(resp)
// reset status code
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
err = req.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return relaycommon.RelayErrorHandler(resp)
}
succeed = true
var audioResponse dto.AudioResponse
defer func(ctx context.Context) {
go func() {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
quota := 0
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = promptTokens
} else {
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
}
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f", modelRatio, groupRatio)
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
} else {
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
contains, words := service.SensitiveWordContains(audioResponse.Text)
if contains {
return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest)
}
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
// reset status code
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}
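
The quota arithmetic spread through `AudioHelper` is easier to follow as a standalone sketch. The snippet below assumes only what the visible lines show: the pre-consumed quota is `tokens * ratio`, it is waived when the user's balance exceeds 100 times that amount, and the final charge is floored at 1 whenever the ratio is non-zero. The function names and sample numbers are hypothetical, and lines hidden by the hunk boundaries may add further conditions.

```go
package main

import "fmt"

// preConsumeQuota returns the quota reserved before the upstream call.
// A balance above 100x the reservation skips the reservation entirely.
func preConsumeQuota(tokens int, ratio float64, userQuota int) int {
	q := int(float64(tokens) * ratio)
	if userQuota > 100*q {
		return 0
	}
	return q
}

// finalQuota returns the quota charged after the call completes.
// A non-zero ratio never yields a charge below 1.
func finalQuota(tokens int, ratio float64) int {
	q := int(float64(tokens) * ratio)
	if ratio != 0 && q <= 0 {
		q = 1
	}
	return q
}

func main() {
	ratio := 7.5 // hypothetical modelRatio * groupRatio
	pre := preConsumeQuota(200, ratio, 10000)
	final := finalQuota(200, ratio)
	// The delta is what gets settled against the reservation afterwards.
	fmt.Println("pre:", pre, "final:", final, "delta:", final-pre)
}
```

The delta corresponds to `quotaDelta := quota - preConsumedQuota` in the diff: a positive value is charged on top of the reservation, a negative one is refunded.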

@@ -2,6 +2,7 @@ package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -13,71 +14,72 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"strings"
"time"
)
func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
imageRequest := &dto.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
var imageRequest dto.ImageRequest
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return nil, err
return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageRequest.Prompt == "" {
return nil, errors.New("prompt is required")
}
if strings.Contains(imageRequest.Size, "×") {
return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
}
if imageRequest.N == 0 {
imageRequest.N = 1
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-3"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e-2"
if imageRequest.N == 0 {
imageRequest.N = 1
}
if imageRequest.Quality == "" {
imageRequest.Quality = "standard"
// Prompt validation
if imageRequest.Prompt == "" {
return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
if constant.ShouldCheckPromptSensitive() {
err = service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
if strings.Contains(imageRequest.Size, "×") {
return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
} else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024")
return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest)
}
if imageRequest.N != 1 {
return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest)
}
//if imageRequest.N != 1 {
// return nil, errors.New("n must be 1")
//}
}
// N should be between 1 and 10
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
if constant.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
return nil, err
}
}
return imageRequest, nil
}
func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
@@ -86,9 +88,31 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
}
if modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
}
relayInfo.UpstreamModelName = imageRequest.Model
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := relaycommon.GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
}
var requestBody io.Reader
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelPrice, success := common.GetModelPrice(imageRequest.Model, true)
if !success {
@@ -97,9 +121,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
// per 1 modelRatio = $0.04 / 16
modelPrice = 0.0025 * modelRatio
}
groupRatio := common.GetGroupRatio(relayInfo.Group)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
groupRatio := common.GetGroupRatio(group)
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
@@ -124,63 +147,101 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
if userQuota-quota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
return service.OpenAIErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
requestBody = bytes.NewBuffer(jsonData)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
err = req.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return relaycommon.RelayErrorHandler(resp)
}
var textResponse dto.ImageResponse
defer func(ctx context.Context) {
useTimeSeconds := time.Now().Unix() - startTime.Unix()
if resp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(resp)
// reset status code
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
return
}
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
quality := "normal"
if imageRequest.Quality == "hd" {
quality = "hd"
}
logContent := fmt.Sprintf("model price %.2f, group ratio %.2f, size %s, quality %s", modelPrice, groupRatio, imageRequest.Size, quality)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
// reset status code
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
usage := &dto.Usage{
PromptTokens: imageRequest.N,
TotalTokens: imageRequest.N,
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
quality := "standard"
if imageRequest.Quality == "hd" {
quality = "hd"
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
logContent := fmt.Sprintf("size %s, quality %s", imageRequest.Size, quality)
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent)
return nil
}
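
The per-model size checks in the image helper could also be expressed as a lookup table. This is a sketch under assumptions, not the code in either commit: `allowedSizes` and `validateSize` are hypothetical names, the size lists are copied from the conditions visible in the diff, and unknown models or empty sizes are simply passed through.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// allowedSizes lists the sizes each model accepts, as read from the diff.
var allowedSizes = map[string][]string{
	"dall-e":   {"256x256", "512x512", "1024x1024"},
	"dall-e-2": {"256x256", "512x512", "1024x1024"},
	"dall-e-3": {"1024x1024", "1024x1792", "1792x1024"},
}

// validateSize rejects the multiplication sign and any size the model does not list.
func validateSize(model, size string) error {
	if strings.Contains(size, "×") {
		return errors.New("use 'x' instead of the multiplication sign '×' in size")
	}
	sizes, known := allowedSizes[model]
	if !known || size == "" {
		return nil // nothing to check for unknown models or an unset size
	}
	for _, s := range sizes {
		if s == size {
			return nil
		}
	}
	return fmt.Errorf("size %q is not supported by %s", size, model)
}

func main() {
	fmt.Println(validateSize("dall-e-3", "1024x1792"))
	fmt.Println(validateSize("dall-e-2", "1024x1792"))
	fmt.Println(validateSize("dall-e-3", "1024×1024"))
}
```

A table keeps the accepted sizes in one place, so the error message can name the offending model instead of repeating one combined string for every case.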

Some files were not shown because too many files have changed in this diff.