feat: codex channel (#2652)

* feat: codex channel

* feat: codex channel

* feat: codex oauth flow

* feat: codex refresh cred

* feat: codex usage

* fix: codex err message detail

* fix: codex setting ui

* feat: codex refresh cred task

* fix: import err

* fix: codex store must be false

* fix: chat -> responses tool call

* fix: chat -> responses tool call
This commit is contained in:
Seefs
2026-01-14 22:29:43 +08:00
committed by GitHub
parent ca11fcbabd
commit e5cb9ac03a
28 changed files with 2052 additions and 32 deletions

View File

@@ -0,0 +1,104 @@
package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
)
type CodexCredentialRefreshOptions struct {
ResetCaches bool
}
type CodexOAuthKey struct {
IDToken string `json:"id_token,omitempty"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
AccountID string `json:"account_id,omitempty"`
LastRefresh string `json:"last_refresh,omitempty"`
Email string `json:"email,omitempty"`
Type string `json:"type,omitempty"`
Expired string `json:"expired,omitempty"`
}
func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
if strings.TrimSpace(raw) == "" {
return nil, errors.New("codex channel: empty oauth key")
}
var key CodexOAuthKey
if err := common.Unmarshal([]byte(raw), &key); err != nil {
return nil, errors.New("codex channel: invalid oauth key json")
}
return &key, nil
}
func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
ch, err := model.GetChannelById(channelID, true)
if err != nil {
return nil, nil, err
}
if ch == nil {
return nil, nil, fmt.Errorf("channel not found")
}
if ch.Type != constant.ChannelTypeCodex {
return nil, nil, fmt.Errorf("channel type is not Codex")
}
oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
if err != nil {
return nil, nil, err
}
if strings.TrimSpace(oauthKey.RefreshToken) == "" {
return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential")
}
refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
if err != nil {
return nil, nil, err
}
oauthKey.AccessToken = res.AccessToken
oauthKey.RefreshToken = res.RefreshToken
oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
if strings.TrimSpace(oauthKey.Type) == "" {
oauthKey.Type = "codex"
}
if strings.TrimSpace(oauthKey.AccountID) == "" {
if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
oauthKey.AccountID = accountID
}
}
if strings.TrimSpace(oauthKey.Email) == "" {
if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
oauthKey.Email = email
}
}
encoded, err := common.Marshal(oauthKey)
if err != nil {
return nil, nil, err
}
if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error; err != nil {
return nil, nil, err
}
if opts.ResetCaches {
model.InitChannelCache()
ResetProxyClientCache()
}
return oauthKey, ch, nil
}

View File

@@ -0,0 +1,140 @@
package service
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/bytedance/gopkg/util/gopool"
)
const (
codexCredentialRefreshTickInterval = 10 * time.Minute
codexCredentialRefreshThreshold = 24 * time.Hour
codexCredentialRefreshBatchSize = 200
codexCredentialRefreshTimeout = 15 * time.Second
)
var (
codexCredentialRefreshOnce sync.Once
codexCredentialRefreshRunning atomic.Bool
)
func StartCodexCredentialAutoRefreshTask() {
codexCredentialRefreshOnce.Do(func() {
if !common.IsMasterNode {
return
}
gopool.Go(func() {
logger.LogInfo(context.Background(), fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold))
ticker := time.NewTicker(codexCredentialRefreshTickInterval)
defer ticker.Stop()
runCodexCredentialAutoRefreshOnce()
for range ticker.C {
runCodexCredentialAutoRefreshOnce()
}
})
})
}
func runCodexCredentialAutoRefreshOnce() {
if !codexCredentialRefreshRunning.CompareAndSwap(false, true) {
return
}
defer codexCredentialRefreshRunning.Store(false)
ctx := context.Background()
now := time.Now()
var refreshed int
var scanned int
offset := 0
for {
var channels []*model.Channel
err := model.DB.
Select("id", "name", "key", "status", "channel_info").
Where("type = ? AND status = 1", constant.ChannelTypeCodex).
Order("id asc").
Limit(codexCredentialRefreshBatchSize).
Offset(offset).
Find(&channels).Error
if err != nil {
logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err))
return
}
if len(channels) == 0 {
break
}
offset += codexCredentialRefreshBatchSize
for _, ch := range channels {
if ch == nil {
continue
}
scanned++
if ch.ChannelInfo.IsMultiKey {
continue
}
rawKey := strings.TrimSpace(ch.Key)
if rawKey == "" {
continue
}
oauthKey, err := parseCodexOAuthKey(rawKey)
if err != nil {
continue
}
refreshToken := strings.TrimSpace(oauthKey.RefreshToken)
if refreshToken == "" {
continue
}
expiredAtRaw := strings.TrimSpace(oauthKey.Expired)
expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw)
if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold {
continue
}
refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout)
newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false})
cancel()
if err != nil {
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err))
continue
}
refreshed++
logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired))
}
}
if refreshed > 0 {
func() {
defer func() {
if r := recover(); r != nil {
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r))
}
}()
model.InitChannelCache()
}()
ResetProxyClientCache()
}
if common.DebugEnabled {
logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed)
}
}

288
service/codex_oauth.go Normal file
View File

@@ -0,0 +1,288 @@
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
const (
codexOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
codexOAuthAuthorizeURL = "https://auth.openai.com/oauth/authorize"
codexOAuthTokenURL = "https://auth.openai.com/oauth/token"
codexOAuthRedirectURI = "http://localhost:1455/auth/callback"
codexOAuthScope = "openid profile email offline_access"
codexJWTClaimPath = "https://api.openai.com/auth"
defaultHTTPTimeout = 20 * time.Second
)
type CodexOAuthTokenResult struct {
AccessToken string
RefreshToken string
ExpiresAt time.Time
}
type CodexOAuthAuthorizationFlow struct {
State string
Verifier string
Challenge string
AuthorizeURL string
}
func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
client := &http.Client{Timeout: defaultHTTPTimeout}
return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
}
func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
client := &http.Client{Timeout: defaultHTTPTimeout}
return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
}
func CreateCodexOAuthAuthorizationFlow() (*CodexOAuthAuthorizationFlow, error) {
state, err := createStateHex(16)
if err != nil {
return nil, err
}
verifier, challenge, err := generatePKCEPair()
if err != nil {
return nil, err
}
u, err := buildCodexAuthorizeURL(state, challenge)
if err != nil {
return nil, err
}
return &CodexOAuthAuthorizationFlow{
State: state,
Verifier: verifier,
Challenge: challenge,
AuthorizeURL: u,
}, nil
}
func refreshCodexOAuthToken(
ctx context.Context,
client *http.Client,
tokenURL string,
clientID string,
refreshToken string,
) (*CodexOAuthTokenResult, error) {
rt := strings.TrimSpace(refreshToken)
if rt == "" {
return nil, errors.New("empty refresh_token")
}
form := url.Values{}
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", rt)
form.Set("client_id", clientID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var payload struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("codex oauth refresh failed: status=%d", resp.StatusCode)
}
if strings.TrimSpace(payload.AccessToken) == "" || strings.TrimSpace(payload.RefreshToken) == "" || payload.ExpiresIn <= 0 {
return nil, errors.New("codex oauth refresh response missing fields")
}
return &CodexOAuthTokenResult{
AccessToken: strings.TrimSpace(payload.AccessToken),
RefreshToken: strings.TrimSpace(payload.RefreshToken),
ExpiresAt: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second),
}, nil
}
func exchangeCodexAuthorizationCode(
ctx context.Context,
client *http.Client,
tokenURL string,
clientID string,
code string,
verifier string,
redirectURI string,
) (*CodexOAuthTokenResult, error) {
c := strings.TrimSpace(code)
v := strings.TrimSpace(verifier)
if c == "" {
return nil, errors.New("empty authorization code")
}
if v == "" {
return nil, errors.New("empty code_verifier")
}
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", clientID)
form.Set("code", c)
form.Set("code_verifier", v)
form.Set("redirect_uri", redirectURI)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var payload struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("codex oauth code exchange failed: status=%d", resp.StatusCode)
}
if strings.TrimSpace(payload.AccessToken) == "" || strings.TrimSpace(payload.RefreshToken) == "" || payload.ExpiresIn <= 0 {
return nil, errors.New("codex oauth token response missing fields")
}
return &CodexOAuthTokenResult{
AccessToken: strings.TrimSpace(payload.AccessToken),
RefreshToken: strings.TrimSpace(payload.RefreshToken),
ExpiresAt: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second),
}, nil
}
func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
u, err := url.Parse(codexOAuthAuthorizeURL)
if err != nil {
return "", err
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", codexOAuthClientID)
q.Set("redirect_uri", codexOAuthRedirectURI)
q.Set("scope", codexOAuthScope)
q.Set("code_challenge", challenge)
q.Set("code_challenge_method", "S256")
q.Set("state", state)
q.Set("id_token_add_organizations", "true")
q.Set("codex_cli_simplified_flow", "true")
q.Set("originator", "codex_cli_rs")
u.RawQuery = q.Encode()
return u.String(), nil
}
func createStateHex(nBytes int) (string, error) {
if nBytes <= 0 {
return "", errors.New("invalid state bytes length")
}
b := make([]byte, nBytes)
if _, err := rand.Read(b); err != nil {
return "", err
}
return fmt.Sprintf("%x", b), nil
}
func generatePKCEPair() (verifier string, challenge string, err error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", "", err
}
verifier = base64.RawURLEncoding.EncodeToString(b)
sum := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
return verifier, challenge, nil
}
func ExtractCodexAccountIDFromJWT(token string) (string, bool) {
claims, ok := decodeJWTClaims(token)
if !ok {
return "", false
}
raw, ok := claims[codexJWTClaimPath]
if !ok {
return "", false
}
obj, ok := raw.(map[string]any)
if !ok {
return "", false
}
v, ok := obj["chatgpt_account_id"]
if !ok {
return "", false
}
s, ok := v.(string)
if !ok {
return "", false
}
s = strings.TrimSpace(s)
if s == "" {
return "", false
}
return s, true
}
func ExtractEmailFromJWT(token string) (string, bool) {
claims, ok := decodeJWTClaims(token)
if !ok {
return "", false
}
v, ok := claims["email"]
if !ok {
return "", false
}
s, ok := v.(string)
if !ok {
return "", false
}
s = strings.TrimSpace(s)
if s == "" {
return "", false
}
return s, true
}
func decodeJWTClaims(token string) (map[string]any, bool) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, false
}
payloadRaw, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, false
}
var claims map[string]any
if err := json.Unmarshal(payloadRaw, &claims); err != nil {
return nil, false
}
return claims, true
}

View File

@@ -0,0 +1,56 @@
package service
import (
"context"
"fmt"
"io"
"net/http"
"strings"
)
func FetchCodexWhamUsage(
ctx context.Context,
client *http.Client,
baseURL string,
accessToken string,
accountID string,
) (statusCode int, body []byte, err error) {
if client == nil {
return 0, nil, fmt.Errorf("nil http client")
}
bu := strings.TrimRight(strings.TrimSpace(baseURL), "/")
if bu == "" {
return 0, nil, fmt.Errorf("empty baseURL")
}
at := strings.TrimSpace(accessToken)
aid := strings.TrimSpace(accountID)
if at == "" {
return 0, nil, fmt.Errorf("empty accessToken")
}
if aid == "" {
return 0, nil, fmt.Errorf("empty accountID")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, bu+"/backend-api/wham/usage", nil)
if err != nil {
return 0, nil, err
}
req.Header.Set("Authorization", "Bearer "+at)
req.Header.Set("chatgpt-account-id", aid)
req.Header.Set("Accept", "application/json")
if req.Header.Get("originator") == "" {
req.Header.Set("originator", "codex_cli_rs")
}
resp, err := client.Do(req)
if err != nil {
return 0, nil, err
}
defer resp.Body.Close()
body, err = io.ReadAll(resp.Body)
if err != nil {
return resp.StatusCode, nil, err
}
return resp.StatusCode, body, nil
}

View File

@@ -54,6 +54,38 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
continue
}
if role == "tool" || role == "function" {
callID := strings.TrimSpace(msg.ToolCallId)
var output any
if msg.Content == nil {
output = ""
} else if msg.IsStringContent() {
output = msg.StringContent()
} else {
if b, err := common.Marshal(msg.Content); err == nil {
output = string(b)
} else {
output = fmt.Sprintf("%v", msg.Content)
}
}
if callID == "" {
inputItems = append(inputItems, map[string]any{
"role": "user",
"content": fmt.Sprintf("[tool_output_missing_call_id] %v", output),
})
continue
}
inputItems = append(inputItems, map[string]any{
"type": "function_call_output",
"call_id": callID,
"output": output,
})
continue
}
// Prefer mapping system/developer messages into `instructions`.
if role == "system" || role == "developer" {
if msg.Content == nil {
@@ -88,12 +120,54 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
if msg.Content == nil {
item["content"] = ""
inputItems = append(inputItems, item)
if role == "assistant" {
for _, tc := range msg.ParseToolCalls() {
if strings.TrimSpace(tc.ID) == "" {
continue
}
if tc.Type != "" && tc.Type != "function" {
continue
}
name := strings.TrimSpace(tc.Function.Name)
if name == "" {
continue
}
inputItems = append(inputItems, map[string]any{
"type": "function_call",
"call_id": tc.ID,
"name": name,
"arguments": tc.Function.Arguments,
})
}
}
continue
}
if msg.IsStringContent() {
item["content"] = msg.StringContent()
inputItems = append(inputItems, item)
if role == "assistant" {
for _, tc := range msg.ParseToolCalls() {
if strings.TrimSpace(tc.ID) == "" {
continue
}
if tc.Type != "" && tc.Type != "function" {
continue
}
name := strings.TrimSpace(tc.Function.Name)
if name == "" {
continue
}
inputItems = append(inputItems, map[string]any{
"type": "function_call",
"call_id": tc.ID,
"name": name,
"arguments": tc.Function.Arguments,
})
}
}
continue
}
@@ -127,7 +201,6 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
"video_url": part.VideoUrl,
})
default:
// Best-effort: keep unknown parts as-is to avoid silently dropping context.
contentParts = append(contentParts, map[string]any{
"type": part.Type,
})
@@ -135,6 +208,27 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
}
item["content"] = contentParts
inputItems = append(inputItems, item)
if role == "assistant" {
for _, tc := range msg.ParseToolCalls() {
if strings.TrimSpace(tc.ID) == "" {
continue
}
if tc.Type != "" && tc.Type != "function" {
continue
}
name := strings.TrimSpace(tc.Function.Name)
if name == "" {
continue
}
inputItems = append(inputItems, map[string]any{
"type": "function_call",
"call_id": tc.ID,
"name": name,
"arguments": tc.Function.Arguments,
})
}
}
}
inputRaw, err := common.Marshal(inputItems)