mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 02:57:27 +00:00
feat: codex channel (#2652)
* feat: codex channel * feat: codex channel * feat: codex oauth flow * feat: codex refresh cred * feat: codex usage * fix: codex err message detail * fix: codex setting ui * feat: codex refresh cred task * fix: import err * fix: codex store must be false * fix: chat -> responses tool call * fix: chat -> responses tool call
This commit is contained in:
104
service/codex_credential_refresh.go
Normal file
104
service/codex_credential_refresh.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
)
|
||||
|
||||
type CodexCredentialRefreshOptions struct {
|
||||
ResetCaches bool
|
||||
}
|
||||
|
||||
type CodexOAuthKey struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
AccessToken string `json:"access_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
LastRefresh string `json:"last_refresh,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Expired string `json:"expired,omitempty"`
|
||||
}
|
||||
|
||||
func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, errors.New("codex channel: empty oauth key")
|
||||
}
|
||||
var key CodexOAuthKey
|
||||
if err := common.Unmarshal([]byte(raw), &key); err != nil {
|
||||
return nil, errors.New("codex channel: invalid oauth key json")
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
|
||||
ch, err := model.GetChannelById(channelID, true)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if ch == nil {
|
||||
return nil, nil, fmt.Errorf("channel not found")
|
||||
}
|
||||
if ch.Type != constant.ChannelTypeCodex {
|
||||
return nil, nil, fmt.Errorf("channel type is not Codex")
|
||||
}
|
||||
|
||||
oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if strings.TrimSpace(oauthKey.RefreshToken) == "" {
|
||||
return nil, nil, fmt.Errorf("codex channel: refresh_token is required to refresh credential")
|
||||
}
|
||||
|
||||
refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
oauthKey.AccessToken = res.AccessToken
|
||||
oauthKey.RefreshToken = res.RefreshToken
|
||||
oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)
|
||||
if strings.TrimSpace(oauthKey.Type) == "" {
|
||||
oauthKey.Type = "codex"
|
||||
}
|
||||
|
||||
if strings.TrimSpace(oauthKey.AccountID) == "" {
|
||||
if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
|
||||
oauthKey.AccountID = accountID
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(oauthKey.Email) == "" {
|
||||
if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
|
||||
oauthKey.Email = email
|
||||
}
|
||||
}
|
||||
|
||||
encoded, err := common.Marshal(oauthKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", string(encoded)).Error; err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if opts.ResetCaches {
|
||||
model.InitChannelCache()
|
||||
ResetProxyClientCache()
|
||||
}
|
||||
|
||||
return oauthKey, ch, nil
|
||||
}
|
||||
140
service/codex_credential_refresh_task.go
Normal file
140
service/codex_credential_refresh_task.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
)
|
||||
|
||||
const (
|
||||
codexCredentialRefreshTickInterval = 10 * time.Minute
|
||||
codexCredentialRefreshThreshold = 24 * time.Hour
|
||||
codexCredentialRefreshBatchSize = 200
|
||||
codexCredentialRefreshTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
codexCredentialRefreshOnce sync.Once
|
||||
codexCredentialRefreshRunning atomic.Bool
|
||||
)
|
||||
|
||||
func StartCodexCredentialAutoRefreshTask() {
|
||||
codexCredentialRefreshOnce.Do(func() {
|
||||
if !common.IsMasterNode {
|
||||
return
|
||||
}
|
||||
|
||||
gopool.Go(func() {
|
||||
logger.LogInfo(context.Background(), fmt.Sprintf("codex credential auto-refresh task started: tick=%s threshold=%s", codexCredentialRefreshTickInterval, codexCredentialRefreshThreshold))
|
||||
|
||||
ticker := time.NewTicker(codexCredentialRefreshTickInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runCodexCredentialAutoRefreshOnce()
|
||||
for range ticker.C {
|
||||
runCodexCredentialAutoRefreshOnce()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func runCodexCredentialAutoRefreshOnce() {
|
||||
if !codexCredentialRefreshRunning.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer codexCredentialRefreshRunning.Store(false)
|
||||
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
var refreshed int
|
||||
var scanned int
|
||||
|
||||
offset := 0
|
||||
for {
|
||||
var channels []*model.Channel
|
||||
err := model.DB.
|
||||
Select("id", "name", "key", "status", "channel_info").
|
||||
Where("type = ? AND status = 1", constant.ChannelTypeCodex).
|
||||
Order("id asc").
|
||||
Limit(codexCredentialRefreshBatchSize).
|
||||
Offset(offset).
|
||||
Find(&channels).Error
|
||||
if err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("codex credential auto-refresh: query channels failed: %v", err))
|
||||
return
|
||||
}
|
||||
if len(channels) == 0 {
|
||||
break
|
||||
}
|
||||
offset += codexCredentialRefreshBatchSize
|
||||
|
||||
for _, ch := range channels {
|
||||
if ch == nil {
|
||||
continue
|
||||
}
|
||||
scanned++
|
||||
if ch.ChannelInfo.IsMultiKey {
|
||||
continue
|
||||
}
|
||||
|
||||
rawKey := strings.TrimSpace(ch.Key)
|
||||
if rawKey == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
oauthKey, err := parseCodexOAuthKey(rawKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(oauthKey.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
expiredAtRaw := strings.TrimSpace(oauthKey.Expired)
|
||||
expiredAt, err := time.Parse(time.RFC3339, expiredAtRaw)
|
||||
if err == nil && !expiredAt.IsZero() && expiredAt.Sub(now) > codexCredentialRefreshThreshold {
|
||||
continue
|
||||
}
|
||||
|
||||
refreshCtx, cancel := context.WithTimeout(ctx, codexCredentialRefreshTimeout)
|
||||
newKey, _, err := RefreshCodexChannelCredential(refreshCtx, ch.Id, CodexCredentialRefreshOptions{ResetCaches: false})
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refresh failed: %v", ch.Id, ch.Name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
refreshed++
|
||||
logger.LogInfo(ctx, fmt.Sprintf("codex credential auto-refresh: channel_id=%d name=%s refreshed, expires_at=%s", ch.Id, ch.Name, newKey.Expired))
|
||||
}
|
||||
}
|
||||
|
||||
if refreshed > 0 {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("codex credential auto-refresh: InitChannelCache panic: %v", r))
|
||||
}
|
||||
}()
|
||||
model.InitChannelCache()
|
||||
}()
|
||||
ResetProxyClientCache()
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
logger.LogDebug(ctx, "codex credential auto-refresh: scanned=%d refreshed=%d", scanned, refreshed)
|
||||
}
|
||||
}
|
||||
288
service/codex_oauth.go
Normal file
288
service/codex_oauth.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
codexOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
codexOAuthAuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
codexOAuthTokenURL = "https://auth.openai.com/oauth/token"
|
||||
codexOAuthRedirectURI = "http://localhost:1455/auth/callback"
|
||||
codexOAuthScope = "openid profile email offline_access"
|
||||
codexJWTClaimPath = "https://api.openai.com/auth"
|
||||
defaultHTTPTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
type CodexOAuthTokenResult struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type CodexOAuthAuthorizationFlow struct {
|
||||
State string
|
||||
Verifier string
|
||||
Challenge string
|
||||
AuthorizeURL string
|
||||
}
|
||||
|
||||
func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
|
||||
client := &http.Client{Timeout: defaultHTTPTimeout}
|
||||
return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
|
||||
}
|
||||
|
||||
func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
|
||||
client := &http.Client{Timeout: defaultHTTPTimeout}
|
||||
return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
|
||||
}
|
||||
|
||||
func CreateCodexOAuthAuthorizationFlow() (*CodexOAuthAuthorizationFlow, error) {
|
||||
state, err := createStateHex(16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verifier, challenge, err := generatePKCEPair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u, err := buildCodexAuthorizeURL(state, challenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CodexOAuthAuthorizationFlow{
|
||||
State: state,
|
||||
Verifier: verifier,
|
||||
Challenge: challenge,
|
||||
AuthorizeURL: u,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func refreshCodexOAuthToken(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
tokenURL string,
|
||||
clientID string,
|
||||
refreshToken string,
|
||||
) (*CodexOAuthTokenResult, error) {
|
||||
rt := strings.TrimSpace(refreshToken)
|
||||
if rt == "" {
|
||||
return nil, errors.New("empty refresh_token")
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", rt)
|
||||
form.Set("client_id", clientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var payload struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("codex oauth refresh failed: status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(payload.AccessToken) == "" || strings.TrimSpace(payload.RefreshToken) == "" || payload.ExpiresIn <= 0 {
|
||||
return nil, errors.New("codex oauth refresh response missing fields")
|
||||
}
|
||||
|
||||
return &CodexOAuthTokenResult{
|
||||
AccessToken: strings.TrimSpace(payload.AccessToken),
|
||||
RefreshToken: strings.TrimSpace(payload.RefreshToken),
|
||||
ExpiresAt: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func exchangeCodexAuthorizationCode(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
tokenURL string,
|
||||
clientID string,
|
||||
code string,
|
||||
verifier string,
|
||||
redirectURI string,
|
||||
) (*CodexOAuthTokenResult, error) {
|
||||
c := strings.TrimSpace(code)
|
||||
v := strings.TrimSpace(verifier)
|
||||
if c == "" {
|
||||
return nil, errors.New("empty authorization code")
|
||||
}
|
||||
if v == "" {
|
||||
return nil, errors.New("empty code_verifier")
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", "authorization_code")
|
||||
form.Set("client_id", clientID)
|
||||
form.Set("code", c)
|
||||
form.Set("code_verifier", v)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var payload struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("codex oauth code exchange failed: status=%d", resp.StatusCode)
|
||||
}
|
||||
if strings.TrimSpace(payload.AccessToken) == "" || strings.TrimSpace(payload.RefreshToken) == "" || payload.ExpiresIn <= 0 {
|
||||
return nil, errors.New("codex oauth token response missing fields")
|
||||
}
|
||||
return &CodexOAuthTokenResult{
|
||||
AccessToken: strings.TrimSpace(payload.AccessToken),
|
||||
RefreshToken: strings.TrimSpace(payload.RefreshToken),
|
||||
ExpiresAt: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
|
||||
u, err := url.Parse(codexOAuthAuthorizeURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", codexOAuthClientID)
|
||||
q.Set("redirect_uri", codexOAuthRedirectURI)
|
||||
q.Set("scope", codexOAuthScope)
|
||||
q.Set("code_challenge", challenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
q.Set("state", state)
|
||||
q.Set("id_token_add_organizations", "true")
|
||||
q.Set("codex_cli_simplified_flow", "true")
|
||||
q.Set("originator", "codex_cli_rs")
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func createStateHex(nBytes int) (string, error) {
|
||||
if nBytes <= 0 {
|
||||
return "", errors.New("invalid state bytes length")
|
||||
}
|
||||
b := make([]byte, nBytes)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%x", b), nil
|
||||
}
|
||||
|
||||
func generatePKCEPair() (verifier string, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
func ExtractCodexAccountIDFromJWT(token string) (string, bool) {
|
||||
claims, ok := decodeJWTClaims(token)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
raw, ok := claims[codexJWTClaimPath]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
obj, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
v, ok := obj["chatgpt_account_id"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
func ExtractEmailFromJWT(token string) (string, bool) {
|
||||
claims, ok := decodeJWTClaims(token)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
v, ok := claims["email"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
func decodeJWTClaims(token string) (map[string]any, bool) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, false
|
||||
}
|
||||
payloadRaw, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
var claims map[string]any
|
||||
if err := json.Unmarshal(payloadRaw, &claims); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return claims, true
|
||||
}
|
||||
56
service/codex_wham_usage.go
Normal file
56
service/codex_wham_usage.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func FetchCodexWhamUsage(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
baseURL string,
|
||||
accessToken string,
|
||||
accountID string,
|
||||
) (statusCode int, body []byte, err error) {
|
||||
if client == nil {
|
||||
return 0, nil, fmt.Errorf("nil http client")
|
||||
}
|
||||
bu := strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
if bu == "" {
|
||||
return 0, nil, fmt.Errorf("empty baseURL")
|
||||
}
|
||||
at := strings.TrimSpace(accessToken)
|
||||
aid := strings.TrimSpace(accountID)
|
||||
if at == "" {
|
||||
return 0, nil, fmt.Errorf("empty accessToken")
|
||||
}
|
||||
if aid == "" {
|
||||
return 0, nil, fmt.Errorf("empty accountID")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, bu+"/backend-api/wham/usage", nil)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+at)
|
||||
req.Header.Set("chatgpt-account-id", aid)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if req.Header.Get("originator") == "" {
|
||||
req.Header.Set("originator", "codex_cli_rs")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return resp.StatusCode, nil, err
|
||||
}
|
||||
return resp.StatusCode, body, nil
|
||||
}
|
||||
@@ -54,6 +54,38 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
|
||||
continue
|
||||
}
|
||||
|
||||
if role == "tool" || role == "function" {
|
||||
callID := strings.TrimSpace(msg.ToolCallId)
|
||||
|
||||
var output any
|
||||
if msg.Content == nil {
|
||||
output = ""
|
||||
} else if msg.IsStringContent() {
|
||||
output = msg.StringContent()
|
||||
} else {
|
||||
if b, err := common.Marshal(msg.Content); err == nil {
|
||||
output = string(b)
|
||||
} else {
|
||||
output = fmt.Sprintf("%v", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if callID == "" {
|
||||
inputItems = append(inputItems, map[string]any{
|
||||
"role": "user",
|
||||
"content": fmt.Sprintf("[tool_output_missing_call_id] %v", output),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
inputItems = append(inputItems, map[string]any{
|
||||
"type": "function_call_output",
|
||||
"call_id": callID,
|
||||
"output": output,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Prefer mapping system/developer messages into `instructions`.
|
||||
if role == "system" || role == "developer" {
|
||||
if msg.Content == nil {
|
||||
@@ -88,12 +120,54 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
|
||||
if msg.Content == nil {
|
||||
item["content"] = ""
|
||||
inputItems = append(inputItems, item)
|
||||
|
||||
if role == "assistant" {
|
||||
for _, tc := range msg.ParseToolCalls() {
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
continue
|
||||
}
|
||||
if tc.Type != "" && tc.Type != "function" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(tc.Function.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": tc.ID,
|
||||
"name": name,
|
||||
"arguments": tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.IsStringContent() {
|
||||
item["content"] = msg.StringContent()
|
||||
inputItems = append(inputItems, item)
|
||||
|
||||
if role == "assistant" {
|
||||
for _, tc := range msg.ParseToolCalls() {
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
continue
|
||||
}
|
||||
if tc.Type != "" && tc.Type != "function" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(tc.Function.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": tc.ID,
|
||||
"name": name,
|
||||
"arguments": tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -127,7 +201,6 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
|
||||
"video_url": part.VideoUrl,
|
||||
})
|
||||
default:
|
||||
// Best-effort: keep unknown parts as-is to avoid silently dropping context.
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": part.Type,
|
||||
})
|
||||
@@ -135,6 +208,27 @@ func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*d
|
||||
}
|
||||
item["content"] = contentParts
|
||||
inputItems = append(inputItems, item)
|
||||
|
||||
if role == "assistant" {
|
||||
for _, tc := range msg.ParseToolCalls() {
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
continue
|
||||
}
|
||||
if tc.Type != "" && tc.Type != "function" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(tc.Function.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": tc.ID,
|
||||
"name": name,
|
||||
"arguments": tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inputRaw, err := common.Marshal(inputItems)
|
||||
|
||||
Reference in New Issue
Block a user