mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-29 23:10:35 +00:00
248 lines
6.6 KiB
Go
248 lines
6.6 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/constant"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/relay/channel/codex"
|
|
"github.com/QuantumNous/new-api/service"
|
|
|
|
"github.com/gin-contrib/sessions"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type codexOAuthCompleteRequest struct {
|
|
Input string `json:"input"`
|
|
}
|
|
|
|
func codexOAuthSessionKey(channelID int, field string) string {
|
|
return fmt.Sprintf("codex_oauth_%s_%d", field, channelID)
|
|
}
|
|
|
|
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
|
|
v := strings.TrimSpace(input)
|
|
if v == "" {
|
|
return "", "", errors.New("empty input")
|
|
}
|
|
if strings.Contains(v, "#") {
|
|
parts := strings.SplitN(v, "#", 2)
|
|
code = strings.TrimSpace(parts[0])
|
|
state = strings.TrimSpace(parts[1])
|
|
return code, state, nil
|
|
}
|
|
if strings.Contains(v, "code=") {
|
|
u, parseErr := url.Parse(v)
|
|
if parseErr == nil {
|
|
q := u.Query()
|
|
code = strings.TrimSpace(q.Get("code"))
|
|
state = strings.TrimSpace(q.Get("state"))
|
|
return code, state, nil
|
|
}
|
|
q, parseErr := url.ParseQuery(v)
|
|
if parseErr == nil {
|
|
code = strings.TrimSpace(q.Get("code"))
|
|
state = strings.TrimSpace(q.Get("state"))
|
|
return code, state, nil
|
|
}
|
|
}
|
|
|
|
code = v
|
|
return code, "", nil
|
|
}
|
|
|
|
func StartCodexOAuth(c *gin.Context) {
|
|
startCodexOAuthWithChannelID(c, 0)
|
|
}
|
|
|
|
func StartCodexOAuthForChannel(c *gin.Context) {
|
|
channelID, err := strconv.Atoi(c.Param("id"))
|
|
if err != nil {
|
|
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
|
return
|
|
}
|
|
startCodexOAuthWithChannelID(c, channelID)
|
|
}
|
|
|
|
func startCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
|
if channelID > 0 {
|
|
ch, err := model.GetChannelById(channelID, false)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
if ch == nil {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
|
return
|
|
}
|
|
if ch.Type != constant.ChannelTypeCodex {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
|
return
|
|
}
|
|
}
|
|
|
|
flow, err := service.CreateCodexOAuthAuthorizationFlow()
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
session := sessions.Default(c)
|
|
session.Set(codexOAuthSessionKey(channelID, "state"), flow.State)
|
|
session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier)
|
|
session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix())
|
|
_ = session.Save()
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "",
|
|
"data": gin.H{
|
|
"authorize_url": flow.AuthorizeURL,
|
|
},
|
|
})
|
|
}
|
|
|
|
func CompleteCodexOAuth(c *gin.Context) {
|
|
completeCodexOAuthWithChannelID(c, 0)
|
|
}
|
|
|
|
func CompleteCodexOAuthForChannel(c *gin.Context) {
|
|
channelID, err := strconv.Atoi(c.Param("id"))
|
|
if err != nil {
|
|
common.ApiError(c, fmt.Errorf("invalid channel id: %w", err))
|
|
return
|
|
}
|
|
completeCodexOAuthWithChannelID(c, channelID)
|
|
}
|
|
|
|
func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
|
|
req := codexOAuthCompleteRequest{}
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
code, state, err := parseCodexAuthorizationInput(req.Input)
|
|
if err != nil {
|
|
common.SysError("failed to parse codex authorization input: " + err.Error())
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析授权信息失败,请检查输入格式"})
|
|
return
|
|
}
|
|
if strings.TrimSpace(code) == "" {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"})
|
|
return
|
|
}
|
|
if strings.TrimSpace(state) == "" {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"})
|
|
return
|
|
}
|
|
|
|
channelProxy := ""
|
|
if channelID > 0 {
|
|
ch, err := model.GetChannelById(channelID, false)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
if ch == nil {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"})
|
|
return
|
|
}
|
|
if ch.Type != constant.ChannelTypeCodex {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
|
|
return
|
|
}
|
|
channelProxy = ch.GetSetting().Proxy
|
|
}
|
|
|
|
session := sessions.Default(c)
|
|
expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string)
|
|
verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string)
|
|
if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"})
|
|
return
|
|
}
|
|
if state != expectedState {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"})
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
|
|
defer cancel()
|
|
|
|
tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy)
|
|
if err != nil {
|
|
common.SysError("failed to exchange codex authorization code: " + err.Error())
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
|
|
return
|
|
}
|
|
|
|
accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken)
|
|
if !ok {
|
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"})
|
|
return
|
|
}
|
|
email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken)
|
|
|
|
key := codex.OAuthKey{
|
|
AccessToken: tokenRes.AccessToken,
|
|
RefreshToken: tokenRes.RefreshToken,
|
|
AccountID: accountID,
|
|
LastRefresh: time.Now().Format(time.RFC3339),
|
|
Expired: tokenRes.ExpiresAt.Format(time.RFC3339),
|
|
Email: email,
|
|
Type: "codex",
|
|
}
|
|
encoded, err := common.Marshal(key)
|
|
if err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
|
|
session.Delete(codexOAuthSessionKey(channelID, "state"))
|
|
session.Delete(codexOAuthSessionKey(channelID, "verifier"))
|
|
session.Delete(codexOAuthSessionKey(channelID, "created_at"))
|
|
_ = session.Save()
|
|
|
|
if channelID > 0 {
|
|
if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil {
|
|
common.ApiError(c, err)
|
|
return
|
|
}
|
|
model.InitChannelCache()
|
|
service.ResetProxyClientCache()
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "saved",
|
|
"data": gin.H{
|
|
"channel_id": channelID,
|
|
"account_id": accountID,
|
|
"email": email,
|
|
"expires_at": key.Expired,
|
|
"last_refresh": key.LastRefresh,
|
|
},
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "generated",
|
|
"data": gin.H{
|
|
"key": string(encoded),
|
|
"account_id": accountID,
|
|
"email": email,
|
|
"expires_at": key.Expired,
|
|
"last_refresh": key.LastRefresh,
|
|
},
|
|
})
|
|
}
|