Compare commits

...

14 Commits

Author SHA1 Message Date
CalciumIon
a075598757 fix: stream options 2024-07-08 21:54:32 +08:00
CalciumIon
a984daa503 feat: update FORCE_STREAM_OPTION default value 2024-07-08 21:41:52 +08:00
CalciumIon
90abe7f27d fix: baidu max_output_tokens (#353) 2024-07-08 19:50:12 +08:00
CalciumIon
bb313eb26f ci: update ci 2024-07-08 19:48:03 +08:00
CalciumIon
02545e4856 fix: baidu max_output_tokens (close #353) 2024-07-08 19:46:45 +08:00
CalciumIon
49cec50908 fix: channel default test model 2024-07-08 17:06:29 +08:00
CalciumIon
4f6710e50c fix: 修复渠道晒筛选后无法展开测试模型 (close #297 #302) 2024-07-08 17:00:10 +08:00
CalciumIon
03b130f2b5 feat: 允许设置是否检测mj任务已完成才可进行action操作 (close #349) 2024-07-08 16:48:10 +08:00
CalciumIon
45b9de9df9 feat: able to use email to login (close #343,#348) 2024-07-08 16:28:56 +08:00
CalciumIon
e062cf32e3 fix: 日志详情 2024-07-08 15:48:28 +08:00
CalciumIon
52debe7572 feat: 完善stream_options 2024-07-08 02:04:21 +08:00
CalciumIon
df6502733c feat: 完善stream_options 2024-07-08 02:00:39 +08:00
CalciumIon
9896ba0a64 feat: support aws stream_options 2024-07-08 01:52:40 +08:00
CalciumIon
e8b93ed6ec feat: support claude stream_options 2024-07-08 01:45:43 +08:00
21 changed files with 127 additions and 47 deletions

View File

@@ -4,6 +4,7 @@ on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:

View File

@@ -4,6 +4,7 @@ var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const (
MjErrorUnknown = 5

View File

@@ -67,8 +67,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
if channel.TestModel != nil && *channel.TestModel != "" {
testModel = *channel.TestModel
} else {
if len(adaptor.GetModelList()) > 0 {
testModel = adaptor.GetModelList()[0]
if len(channel.GetModels()) > 0 {
testModel = channel.GetModels()[0]
} else {
testModel = "gpt-3.5-turbo"
}

View File

@@ -102,6 +102,7 @@ type ChatCompletionsStreamResponse struct {
Model string `json:"model"`
SystemFingerprint *string `json:"system_fingerprint"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
Usage *Usage `json:"usage"`
}
type ChatCompletionsStreamResponseSimple struct {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"gorm.io/gorm"
"one-api/common"
"strings"
)
type Channel struct {
@@ -33,6 +34,13 @@ type Channel struct {
OtherInfo string `json:"other_info"`
}
func (channel *Channel) GetModels() []string {
if channel.Models == "" {
return []string{}
}
return strings.Split(strings.Trim(channel.Models, ","), ",")
}
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {

View File

@@ -99,6 +99,7 @@ func InitOptionMap() {
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
@@ -210,6 +211,8 @@ func updateOptionMap(key string, value string) (err error) {
constant.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":

View File

@@ -295,10 +295,15 @@ func (user *User) ValidateAndFill() (err error) {
// that means if your fields value is 0, '', false or other zero values,
// it wont be used to build query conditions
password := user.Password
if user.Username == "" || password == "" {
if (user.Username == "" && user.Email == "") || password == "" {
return errors.New("用户名或密码为空")
}
DB.Where(User{Username: user.Username}).First(user)
// find buy username or email
if user.Username != "" {
DB.Where(User{Username: user.Username}).First(user)
} else if user.Email != "" {
DB.Where(User{Email: user.Email}).First(user)
}
okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")

View File

@@ -68,7 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = awsStreamHandler(c, info, a.RequestMode)
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}

View File

@@ -13,6 +13,7 @@ import (
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
@@ -112,7 +113,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return nil, &usage
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -162,7 +163,6 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
@@ -214,6 +214,17 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode i
return false
}
})
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}

View File

@@ -19,7 +19,7 @@ type BaiduChatRequest struct {
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}

View File

@@ -23,14 +23,20 @@ var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: int(request.MaxTokens),
UserId: request.User,
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
for _, message := range request.Messages {
if message.Role == "system" {

View File

@@ -330,22 +330,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
response.Created = createdTime
response.Model = info.UpstreamModelName
jsonStr, err := json.Marshal(response)
err = service.ObjectData(c, response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
common.SysError(err.Error())
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
@@ -356,6 +349,18 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
service.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}

View File

@@ -7,7 +7,6 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
@@ -20,8 +19,7 @@ import (
)
type Adaptor struct {
ChannelType int
SupportStreamOptions bool
ChannelType int
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -33,7 +31,6 @@ func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequ
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
a.ChannelType = info.ChannelType
a.SupportStreamOptions = info.SupportStreamOptions
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -81,17 +78,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
if request == nil {
return nil, errors.New("request is nil")
}
// 如果不支持StreamOptions将StreamOptions设置为nil
if !a.SupportStreamOptions {
request.StreamOptions = nil
} else {
// 如果支持StreamOptions且请求中没有设置StreamOptions根据配置文件设置StreamOptions
if constant.ForceStreamOption {
request.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
}
return request, nil
}

View File

@@ -28,6 +28,7 @@ type RelayInfo struct {
Organization string
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
@@ -66,7 +67,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if info.ChannelType == common.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeOpenAI {
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws {
info.SupportStreamOptions = true
}
return info

View File

@@ -415,9 +415,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
if constant.MjActionCheckSuccessEnabled {
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
}
}
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")

View File

@@ -130,6 +130,22 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return openaiErr
}
// 如果不支持StreamOptions将StreamOptions设置为nil
if !relayInfo.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil
} else {
// 如果支持StreamOptions且请求中没有设置StreamOptions根据配置文件设置StreamOptions
if constant.ForceStreamOption {
textRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
}
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
relayInfo.ShouldIncludeUsage = textRequest.StreamOptions.IncludeUsage
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)

View File

@@ -24,3 +24,15 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}

View File

@@ -550,7 +550,7 @@ const ChannelsTable = () => {
);
const { success, message, data } = res.data;
if (success) {
setChannels(data);
setChannelFormat(data);
setActivePage(1);
} else {
showError(message);

View File

@@ -42,6 +42,7 @@ const OperationSetting = () => {
MjAccountFilterEnabled: false,
MjModeClearEnabled: false,
MjForwardUrlEnabled: false,
MjActionCheckSuccessEnabled: false,
DrawingEnabled: false,
DataExportEnabled: false,
DataExportDefaultTime: 'hour',

View File

@@ -153,8 +153,8 @@ export function renderModelPrice(
let inputRatioPrice = modelRatio * 2.0;
let completionRatioPrice = modelRatio * 2.0 * completionRatio;
let price =
(inputTokens / 1000000) * inputRatioPrice +
(completionTokens / 1000000) * completionRatioPrice;
(inputTokens / 1000000) * inputRatioPrice * groupRatio +
(completionTokens / 1000000) * completionRatioPrice * groupRatio;
return (
<>
<article>

View File

@@ -16,6 +16,7 @@ export default function SettingsDrawing(props) {
MjAccountFilterEnabled: false,
MjForwardUrlEnabled: false,
MjModeClearEnabled: false,
MjActionCheckSuccessEnabled: false,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -156,6 +157,25 @@ export default function SettingsDrawing(props) {
}
/>
</Col>
<Col span={8}>
<Form.Switch
field={'MjActionCheckSuccessEnabled'}
label={
<>
检测必须等待绘图成功才能进行放大等操作
</>
}
size='large'
checkedText=''
uncheckedText=''
onChange={(value) =>
setInputs({
...inputs,
MjActionCheckSuccessEnabled: value,
})
}
/>
</Col>
</Row>
<Row>
<Button size='large' onClick={onSubmit}>