Compare commits

...

90 Commits

Author SHA1 Message Date
CaIon
1e25bf700d Merge remote-tracking branch 'origin/alpha' into alpha 2025-07-05 14:14:48 +08:00
CaIon
30fb349d91 feat(endpoint types): add support for image generation models in endpoint type handling 2025-07-05 14:14:40 +08:00
t0ng7u
d40fb68500 📊 feat(detail): add model consumption trend & call ranking charts
Introduce two new visualizations to the “Model Data Analysis” panel:

1. Model Consumption Trend (line chart)
   • Added `spec_model_line` state and legend support.
   • Calculates per-model counts over time and updates via `updateChartData`.
2. Model Call Ranking (bar chart)
   • Added `spec_rank_bar` state with `seriesField` and legend enabled.
   • Ranks models by total call count.

Additional changes:
• Extended tab navigation with two new `TabPane`s and adjusted chart rendering logic.
• Swapped icons/texts to match new chart purposes.
• Reused existing color mapping to ensure consistent palette.

No breaking changes; UI now offers richer insights into model usage patterns.
2025-07-05 00:37:05 +08:00
t0ng7u
3049ad47e5 🔢 feat(user-edit): replace add-quota input with Semi-UI InputNumber
Summary:
• Imported InputNumber from @douyinfe/semi-ui.
• Swapped plain Input for InputNumber in “Add Quota” modal.
• Added UX tweaks: full-width styling, showClear, step = 500 000.
• Initialized addQuotaLocal to an empty string so the field starts blank.
• Adjusted state handling and kept quota calculation logic unchanged.

This improves numeric input accuracy and overall user experience without breaking existing functionality.
2025-07-05 00:03:12 +08:00
t0ng7u
8945a3a2dd 🖼️ style(RatioSync): remove the useless rounded-full style 2025-07-04 23:49:34 +08:00
t0ng7u
d191eef657 🐛 fix: fix the header height calculation issue in the custom HTML styles on the homepage 2025-07-04 23:42:46 +08:00
CaIon
6ac7878863 🔧 refactor(endpoint types): comment out unused endpoint types in constants 2025-07-04 15:53:46 +08:00
t0ng7u
c0a23ffa62 🎨 refactor(EditTagModal): tidy imports & enhance state-sync on open
Motivation
• Remove unused UI components to keep the bundle lean and silence linter warnings.
• Ensure every time the side-sheet opens it reflects the latest tag data, avoiding stale form values (e.g., model / group mismatches).

Key Changes
1. UI Imports
   – Dropped `Input`, `Select`, `TextArea` from `@douyinfe/semi-ui` (unused in Form-based version).
2. State Reset & Form Sync
   – On `visible` or `tag` change:
     • Refresh model & group options.
     • Reset `inputs` to clean defaults (`originInputs`) carrying the current `tag`.
     • Pre-fill Form through `formApiRef` to keep controlled fields aligned.
3. Minor Cleanup
   – Added inline comment clarifying local state reset purpose.

Result
Opening the “Edit Tag” side-sheet now always displays accurate data without residual selections, and build output is cleaner due to removed dead imports.
2025-07-04 06:14:15 +08:00
t0ng7u
7d691f362d refactor(EditChannel&EditToken): refactor Channel & Token edit pages with Semi Form and UX enhancements
Overview
• Migrated both `EditChannel.js` and `EditToken.js` to fully leverage Semi UI `Form.*` components, removing legacy `Input/Select/TextArea` + manual labels.
• Unified data-loading strategy: when the drawer becomes visible we load (or reset) data via `props.visible + id` effect and `formApi.setValues()`, guaranteeing fields are always populated; form resets on close.
• Fixed blank-form bug when opening the same record twice.

Key improvements
1. Validation
   • `type`, `models` always required.
   • `key` required only while creating (not on edit).
2. Batch key creation
   • Checkbox moved into `extraText`; hidden when editing or when channel type = 41.
3. Layout & UI
   • `Row / Col` (12 + 12) for “Priority” and “Weight”.
   • Placeholders revised; model selector now shows creation hint; removed obsolete banner.
   • Help / extraText used for long hints, template buttons (`model_mapping`, `status_code_mapping`, `param_override`, etc.), and API address notice.
   • Added `showClear`, `min`, rounded card class names for consistency.
4. Reusable helpers
   • `batchAllowed`, `batchExtra` utilities.
   • `getInitValues()` + centralized `inputs`→form synchronization.
5. Token editor aligned to the same pattern (`props.visiable` watcher).

Result
Cleaner code, consistent UX, instant field population on every open, and clearer validation/error feedback across both editors.
2025-07-04 05:36:10 +08:00
t0ng7u
bf577b8937 🔌 feat(api): extend endpoint type support & expose in pricing UI
* backend
  - constant/endpoint_type.go
    • Add EndpointTypeMidjourney, EndpointTypeSuno, EndpointTypeKling, EndpointTypeJimeng.
  - common/endpoint_type.go
    • Map Midjourney / MidjourneyPlus, SunoAPI, Kling, Jimeng channel types to the new endpoint types.

* frontend
  - ModelPricing.js
    • Add “Supported Endpoint Type” column.
    • Implement renderSupportedEndpoints with `stringToColor` for consistent tag colors.

These changes allow `/api/pricing` and model lists to return accurate
`supported_endpoint_types` covering all non-OpenAI providers and display
them clearly in the UI.

No breaking changes.
2025-07-04 03:15:34 +08:00
Calcium-Ion
819290c9b8 Merge pull request #1314 from vickyyd/main
修复使用gemini-balance作为上游时,测试gemini2.5pro模型时出现的错误问题
2025-07-03 15:53:32 +08:00
CaIon
22e8b46159 feat: make TopN field in RerankRequest optional in JSON serialization 2025-07-03 15:45:32 +08:00
CaIon
76b8cc1168 feat: add pull request template and enforce branching strategy in workflow 2025-07-03 13:33:50 +08:00
Calcium-Ion
fce07325b9 Merge pull request #1325 from feitianbubu/pr/fix-ali-embedding-lost-prompt-token
fix: ali embedding lose prompt_tokens
2025-07-03 13:26:51 +08:00
Calcium-Ion
123862d41c Merge pull request #1326 from QuantumNous/refactor_constant
 feat: refactor environment variable initialization
2025-07-03 13:18:41 +08:00
CaIon
7e298f8ad1 feat: refactor environment variable initialization and introduce new constant types for API and context keys 2025-07-03 13:10:25 +08:00
IcedTangerine
34aca14858 Merge pull request #1309 from feitianbubu/pr/alpha/video-action-constant2
feat: video action to constant
2025-07-02 15:50:23 +08:00
skynono
6b1f94348a fix: ali embedding lose prompt_tokens 2025-07-02 15:12:02 +08:00
CaIon
4322037639 🐛 fix: correct validation logic for redemption name input in EditRedemption component 2025-07-02 10:28:57 +08:00
CaIon
ae11f88595 feat: increase Node.js memory limit in macOS release workflow 2025-07-01 13:23:29 +08:00
CaIon
389a4c3e4c Merge branch 'main' into alpha 2025-07-01 13:15:47 +08:00
CaIon
efb691e6c2 Merge remote-tracking branch 'origin/alpha' into alpha 2025-07-01 13:14:40 +08:00
CaIon
53e3b35437 feat: enhance JWT exchange process with proxy support. (close #1087) 2025-07-01 13:14:24 +08:00
CaIon
eb265a55e1 feat: enhance environment configuration and resource initialization 2025-07-01 13:13:30 +08:00
Calcium-Ion
950f7d214f Merge pull request #1322 from feitianbubu/pr/jimeng-key-delimiter
feat: jimeng apiKey format to use `|` delimiter
2025-07-01 10:44:19 +08:00
skynono
6bd2316d9c feat: jimeng apiKey format to use | delimiter 2025-07-01 10:35:29 +08:00
kikii16
efc8457770 修复gemini-balance测试gemini2.5pro的错误问题 2025-06-29 13:36:19 +08:00
同語
9b8b982d8a 🐛 fix: ratelimit style error
Merge pull request #1301 from tbphp/fix_ratelimit_style
2025-06-29 02:34:06 +08:00
t0ng7u
e6949e611a style: change the border radius of most components from full to lg size 2025-06-29 02:32:09 +08:00
t0ng7u
cffade7210 🤯style: remove useless card headerStyle 2025-06-29 00:11:15 +08:00
CaIon
6b9237f868 🐛 fix: refactor JSON unmarshalling across multiple handlers to use UnmarshalJson and UnmarshalJsonStr for consistency
This update replaces instances of DecodeJson and DecodeJsonStr with UnmarshalJson and UnmarshalJsonStr in various relay handlers, enhancing code consistency and clarity in JSON processing. The changes improve maintainability and align with recent refactoring efforts in the codebase.
2025-06-28 00:02:07 +08:00
CaIon
1f4cf07b63 🐛 fix: refactor response body handling in multiple relay handlers to utilize IOCopyBytesGracefully 2025-06-27 23:35:56 +08:00
skynono
59a1f4c900 feat: video action to constant 2025-06-27 23:19:34 +08:00
CaIon
0a04a76c71 🐛 fix: refactor JSON encoding and decoding in OpenAI handlers for improved consistency 2025-06-27 22:45:36 +08:00
CaIon
9e6bc518cc 🐛 fix: refactor OaiStreamHandler to improve last response handling and streamline response body closure 2025-06-27 22:44:20 +08:00
CaIon
bfb6fbbac9 🐛 fix: update hardcoded completion model ratio for gemini-2.5-flash-lite 2025-06-27 22:36:23 +08:00
CaIon
9c08d8cf20 feat: introduce IOCopyBytesGracefully function for streamlined response body handling
This update adds the IOCopyBytesGracefully function to the common package, which simplifies the process of copying response bodies in the OpenAI handlers. It enhances error handling and ensures proper resource management by encapsulating the logic for setting headers and writing response data. The OpenAI handlers have been refactored to utilize this new function, improving code clarity and maintainability.
2025-06-27 22:36:12 +08:00
CaIon
281054ff4c 🐛 fix: replace direct response body closure with common.CloseResponseBodyGracefully for improved error handling
This update standardizes the closure of HTTP response bodies across multiple stream handlers, enhancing error management and resource cleanup. The new method ensures that any errors during closure are handled gracefully, preventing potential request termination issues.
2025-06-27 21:40:36 +08:00
CaIon
3002659f47 feat: add CloseResponseBodyGracefully function to handle HTTP response body closure 2025-06-27 21:37:13 +08:00
Calcium-Ion
647f8d7958 Merge pull request #1274 from feitianbubu/feat/add-channel-jimeng
feat: 支持即梦视频渠道
2025-06-27 21:16:50 +08:00
CaIon
5d289d38ba 🐛 fix: handle response body errors more gracefully in OpenAI handler
Changes:
- Replaced error returns with logging for response body copy failures to prevent early termination of the request.
- Ensured that the response body is closed properly after writing to the client.
- Added comments to clarify the handling of billing and error reporting after the response has been sent.

This update improves error handling and maintains resource management in the OpenAI handler.
2025-06-27 21:13:21 +08:00
skynono
05ea0dd54f feat: add video channel jimeng 2025-06-27 17:08:20 +08:00
CaIon
1dad04ec09 feat: add Function and Container fields to ResponsesToolsCall struct #1305 2025-06-27 16:56:54 +08:00
Xyfacai
2171117c53 Merge pull request #1291 from feitianbubu/pr/add-origin-kling-api
feat: add origin kling api
2025-06-27 16:08:03 +08:00
Xyfacai
d389befc9e Merge pull request #1298 from xiangyuanliu/feat/page-format
feat: 优化分页组件
2025-06-27 15:55:30 +08:00
t0ng7u
3ced5ff144 chore: Improve channel creation UX: defer "Fetch Model List" action until after creation
Previously, the "Fetch Model List" button was visible in the channel-creation view even though
it only functions once a channel record exists, leading to user confusion.

Changes introduced:
• Render the "Fetch Model List" button only when editing an existing channel (`isEdit === true`).
• Display an informational Banner in creation mode to remind users that the upstream model list
  can be fetched after the channel has been created.
• Refactored JSX to apply the above conditional rendering without altering existing logic.

This update streamlines the creation workflow and sets clearer expectations for users.
2025-06-27 10:08:44 +08:00
t0ng7u
38d3ab5acf 💄refactor: enhance EditUser and AddUser form validation & UX
Changes in `web/src/pages/User/EditUser.js`:
• Added `rules` to
  – `Form.Select group`: now required with error “Please select group”.
  – `Form.InputNumber quota`: now required with error “Please enter quota”.
• Added `step={500000}` to quota `InputNumber` for quicker numeric input.
• Replaced invalid `readonly` with React-correct `readOnly`, and added descriptive placeholders for all binding-info fields (GitHub/OIDC/WeChat/Email/Telegram).
• Removed unused `downloadTextAsFile` import.

These updates tighten form validation, improve data entry ergonomics, and restore clear read-only indicators for third-party bindings.
2025-06-27 09:44:18 +08:00
t0ng7u
ab32e15a86 🐛 fix(redemptions-table): correct initial page index and pagination state
Summary:
The redemption list occasionally displayed an invalid range such as “Items -9 - 0” and failed to highlight page 1 after a refresh. This was caused by the table being initialized with `currentPage = 0`.

Changes:
• update `useEffect` to load data starting from page 1 instead of page 0
• refactor `loadRedemptions` to accept `page` (default 1) and sanitize backend‐returned pages (`<= 0` coerced to 1)
• keep other logic unchanged

Impact:
Pagination text and page selection now show correct values on first load or refresh, eliminating negative ranges and ensuring the first page is properly highlighted.
2025-06-27 07:42:04 +08:00
t0ng7u
25e17b95d5 🐛 fix(redemptions-table): show loading indicator while refetching data
Previously, the table did not enter the loading state after performing actions such as deleting, enabling, or disabling a redemption code. This caused a brief period where the UI appeared unresponsive while awaiting the backend response.

Changes made:
• Added `setLoading(true)` at the beginning of `loadRedemptions` to activate the loading spinner whenever data is (re)fetched.
• Added an explanatory code comment to clarify the intent.

This improves user experience by clearly indicating that the system is processing and prevents confusion during data refresh operations.
2025-06-27 07:29:28 +08:00
t0ng7u
d07224e658 🎁 refactor(ui/redemption): migrate EditRedemption page to Semi Form & enhance UX
SUMMARY
• Re-implemented `web/src/pages/Redemption/EditRedemption.js` with Semi Form components, removing legacy local-state handling.
• Added `formApiRef` for centralized control; external “Submit” button now triggers `formApi.submitForm()`.
• Replaced `Input/AutoComplete/DatePicker` etc. with `<Form.*>` fields, leveraging built-in validation & accessibility.
• Field validations:
  – `name` (create only), `quota`, `count` → required with localized messages.
• Expiration-time flow:
  – Default value `null` (no more 1970-01-01).
  – When loading data, convert 0 → null, timestamp → Date.
  – On submit, Date → unix seconds; empty → 0.
• Responsive grid layout (`Row/Col`) for tidy alignment.
• Added helpful `showClear` & full-width styling for inputs; quota presets retained.
• Cleaned unused imports & handlers; fixed linter issues.

RESULT
The Redemption form now benefits from higher performance, clearer validation, and a cleaner codebase consistent with Semi Design best practices.
2025-06-27 07:25:46 +08:00
t0ng7u
aa15d45a3d refactor(ui/token): migrate EditToken page to Semi Form API and polish UX
SUMMARY
• Re-implemented `EditToken.js` with Semi Form components, eliminating manual state handling and reducing re-renders.
• Added grid-based layout; “Expiration Time” selector now sits inline with quick-set buttons for consistent alignment on desktop & mobile.
• Introduced dedicated “Quota”, “Access”, “Model Limits”, and “Group” cards for clearer field grouping.
• Reworked model-limit interaction: single multi-select list replaces checkbox toggle; backend flag `model_limits_enabled` is now inferred automatically.
• Applied required validation rules to critical fields (`name`, `remain_quota`, `group`, `expired_time`, `tokenCount`) with localized messages.
• Enabled dynamic option loading for models & groups; default auto-group honoured.
• Added unlimited-quota switch, quota presets, and helpful extraText/tooltips.
• Removed obsolete `handleInputChange` & `setUnlimitedQuota` helpers; formApi now manages all data flow.
• Cleaned imports (e.g., dropped unused `IconUserGroup`), fixed linter errors, and updated submit logic to use `formApi.submitForm()`.

RESULT
The token creation/editing experience is faster, more accessible, and easier to maintain, fully aligned with Semi Design best practices.
2025-06-26 22:58:25 +08:00
tbphp
c6c68da0b5 fix: ratelimit style error 2025-06-26 21:32:05 +08:00
t0ng7u
1a0aac81df 🎨 style: remove all prefix icons to simplify the layout of the sidesheet component 2025-06-26 16:36:36 +08:00
t0ng7u
39cb45c11c 🎨 style: unify card header UI, switch to Avatar icons & remove oversized props
Summary
• Replaced gradient header blocks with compact, neutral headers wrapped in `Avatar` across the following pages:
  - Channel / EditChannel.js
  - Channel / EditTagModal.js
  - Redemption / EditRedemption.js
  - Token / EditToken.js
  - User / EditUser.js
  - User / AddUser.js

Details
1. Added `Avatar` import and substituted raw icon elements, assigning semantic colors (`blue`, `green`, `purple`, `orange`, etc.) and consistent 16 px icons for a cleaner look.
2. Removed gradient backgrounds, decorative “blur-ball” shapes, and extra paddings from header containers to achieve a tight, flat design.
3. Stripped all `size="large"` attributes from `Button`, `Input`, `Select`, `DatePicker`, `AutoComplete`, and `Avatar` components, allowing default sizing for better visual density.
4. Eliminated redundant `bodyStyle` background overrides in some `SideSheet` components.
5. No business logic touched; all changes are purely presentational.

Result
The editing and creation dialogs now share a unified, compact style consistent with the latest design language, improving readability and user experience without altering functionality.
2025-06-26 16:05:13 +08:00
t0ng7u
05d9aa53ef 🔒 style: Hide registration link when Self-Use Mode is enabled
• Add conditional rendering (`!status.self_use_mode_enabled`) to LoginForm
• Suppress “Don't have an account? Register” CTA in self-hosted scenarios
• Keeps UI clean and prevents unintended user sign-ups under self-use mode
• No impact on regular multi-user deployments
2025-06-26 04:29:44 +08:00
t0ng7u
86f374df58 🐛 fix(auth): prevent duplicate “session expired” toast on login
Login Form used to display the message “未登录或登录已过期,请重新登录” twice
because the `useEffect` that inspects the `expired` query parameter was
re-executed on every re-render (e.g. language change or React StrictMode’s
double-mount in development).

### What's changed
• **LoginForm.js** – `useEffect` that shows the toast now has an empty
  dependency array so it runs only once on initial mount.
• Reviewed **PasswordResetConfirm.js**, **PasswordResetForm.js** and
  **RegisterForm.js** and confirmed they do not contain the same issue;
  no changes were required.

### Impact
Users now see the “session expired” notification exactly once, removing
confusion and improving the overall UX.
2025-06-26 03:51:19 +08:00
t0ng7u
6935260bf0 🧶style(TokensTable): add IconDelete in Delete selected token button 2025-06-25 23:23:59 +08:00
t0ng7u
f0d888729b 🐛 fix(auth): restore proper state & context destructuring in Login- and Register-forms
Why
Clicking the “Continue” button on the login page no longer triggered the submission logic. The issue was introduced when `useState`/`useContext` hooks were destructured incorrectly, breaking the setter reference and omitting required values.

What’s changed
• **LoginForm.js**
  – Re-added setter in `useSearchParams` (`[searchParams, setSearchParams]`).
  – Corrected order of destructuring for `inputs` so `username`/`password` are available after hooks.
  – Switched `useContext` to `[userState, userDispatch]` for consistency.

• **RegisterForm.js**
  – Adopted `[userState, userDispatch]` from `UserContext` to mirror LoginForm and retain full state access.

Outcome
Login button now successfully invokes `handleSubmit`, and both auth components have consistent, fully-featured hook destructuring, preventing runtime errors and ensuring future state usage is straightforward.
2025-06-25 23:13:55 +08:00
t0ng7u
6d7d4292ef 💫 feat(ui): introduce dispersed blur-ball background to all auth views
This commit refreshes the visual design of the authentication pages and aligns them with the Home banner style.

Details
• LoginForm.js / RegisterForm.js / PasswordResetForm.js / PasswordResetConfirm.js
  – Wrap top-level container with `relative overflow-hidden` to provide a positioning context.
  – Inject two decorative blur balls:
    ▸ Indigo ball on the top-right (`blur-ball-indigo`).
    ▸ Teal ball on the middle-left (`blur-ball-teal`).
  – Disabled the default X-axis transform on the indigo ball to keep the ball anchored to the corner.
  – Removed redundant `mt-[64px]` from the outer container and shifted it to the inner wrapper to maintain vertical rhythm without affecting the background placement.

Result
The auth screens now feature subtle, non-intrusive atmospheric gradients in the top-right and mid-left corners, offering a cohesive look & feel across the application without obstructing the main content.
2025-06-25 22:57:04 +08:00
t0ng7u
fcefac9dbe 🐛 fix(auth): prevent initial render flicker & clean up state usage
• LoginForm / RegisterForm now initialise `status` directly from localStorage,
  avoiding a post-mount state update that caused a UI flash between OAuth
  options and email/username forms.

• Move Turnstile configuration into a dedicated effect that depends on
  `status`, ensuring setState is not called during rendering.

• Remove unused `setStatus` setter to resolve ESLint “declared but never read”
  warnings.

• Minor refactors: reorder hooks, de-duplicate navigate/context variables and
  streamline state destructuring for improved readability.
2025-06-25 22:46:11 +08:00
t0ng7u
ad5f731b20 🍭style: add mt-[64px] in class auth componets 2025-06-25 22:21:14 +08:00
Xiangyuan-liu
76da067d40 feat: 优化分页组件 2025-06-25 18:42:19 +08:00
CaIon
0689670698 🔧 fix(xinference): update Document type to 'any' for flexibility
- Changed the type of `Document` in `XinRerankResponseDocument` from `string` to `any` to accommodate various data types.
- Updated the `RerankHandler` to handle `Document` as `any`, ensuring proper assignment based on its actual type.

These modifications enhance the handling of document data, allowing for greater versatility in response structures.
2025-06-25 18:04:34 +08:00
t0ng7u
5a6f32c392 🎨 style(ui): refactor Tabs in ModelPricing to use native Semi UI styling
• Removed the custom `renderArrow` helper and its `Dropdown`-based arrow navigation, simplifying the component logic.
• Switched the `<Tabs>` component to rely on Semi UI’s built-in behaviour (no more `renderArrow` override).
• Kept `type="card"` and `collapsible` props for consistent visual appearance while still using the default style.
• Eliminated the now-unused `Dropdown` import.

This cleanup reduces bespoke UI code, makes future maintenance easier, and keeps the interface consistent with the rest of the application.
2025-06-25 15:40:27 +08:00
t0ng7u
d6276c4692 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-25 15:26:59 +08:00
t0ng7u
29a44eb7ae feat(homepage): enhance banner visuals & UX
• Added read-only Base URL input that shows `status.server_address` (fallback `window.location.origin`) and copies value on click.
• Embedded `ScrollList` as input `suffix`; auto-cycles common endpoints every 3 s and allows manual selection.
• Introduced `API_ENDPOINTS` array in `web/src/constants/common.constant.js` for centralized endpoint management.
• Implemented custom CSS to hide ScrollList wheel indicators / scrollbars for a cleaner look.
• Created two blurred colour spheres behind the banner (`blur-ball-indigo`, `blur-ball-teal`) with light-/dark-mode opacity tweaks and lower vertical placement.
• Increased letter-spacing for Chinese heading via conditional `tracking-wide` / `md:tracking-wider` classes to improve readability.
• Misc: updated imports, helper functions, and responsive sizes to keep UI consistent across devices.
2025-06-25 15:26:51 +08:00
CaIon
048a625181 🚀 feat(auth): support new model API paths in authentication and routing
- Updated TokenAuth middleware to handle requests for both `/v1beta/models/` and `/v1/models/`.
- Adjusted distributor middleware to recognize the new model path.
- Enhanced relay mode determination to include the new model path.
- Added route for handling POST requests to `/models/*path`.

These changes ensure compatibility with the new model API structure, improving the overall routing and authentication flow.
2025-06-25 00:19:38 +08:00
t0ng7u
64782027c4 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-24 18:10:04 +08:00
t0ng7u
277645db50 🔧 style(ui): Inline tag edit action in ChannelsTable
• Removed the dropdown menu previously used for tag-level operations.
• Added a standalone “Edit” button directly after the “Disable All” button, reducing the number of clicks required to edit a tag group.
• Deleted the now-unused `IconEdit` import and its icon reference.

This streamlines the tag management flow and keeps the UI cleaner and more accessible.
2025-06-24 18:09:16 +08:00
CaIon
3f53e4f53e 🔧 fix(model_ratio): adjust return values for gemini-2.5-pro and gemini-2.5-flash models 2025-06-24 18:08:42 +08:00
t0ng7u
0c5d4ca0a7 🎨 style(channels-table): standardize operation component size to small
All operation-related UI controls in `ChannelsTable` (buttons, dropdowns,
switches, inputs, tags, etc.) now explicitly use `size="small"`.

Reasons & benefits:
- Creates a more compact and consistent look across the table and modals.
- Improves visual coherence between desktop and mobile views.
- Purely presentational; no functional logic is affected.

No database changes or API interactions are involved.
2025-06-24 18:02:34 +08:00
t0ng7u
44495b153a 🚀 feat: enhance model testing UI with bulk selection, copy & success-filter buttons (#1288)
* ChannelsTable
  - Added row-level checkboxes to the model-testing table for multi-selection
  - Implemented cross-page “Select All / Deselect All” via rowSelection.onSelectAll
  - Introduced allSelectingRef to ignore redundant onChange after onSelectAll
  - Added “Copy Selected” button to copy chosen model names (comma-separated) using helpers.copy
  - Added “Select Successful” button to auto-tick all models that passed testing
  - Moved search bar and new action buttons into the modal title for better UX
  - Centralised page size constant MODEL_TABLE_PAGE_SIZE in channel.constants.js
  - Fixed pagination slicing and auto-page-switch logic during batch testing

* channel.constants
  - Exported MODEL_TABLE_PAGE_SIZE (default 10) for unified pagination control

This commit enables users to conveniently copy or filter successful models, fully supports cross-page bulk operations, and resolves previous selection inconsistencies.

Refs: #1288
2025-06-24 17:46:08 +08:00
t0ng7u
de6e551cdb fix: ensure table shows correct loading state on first render & during search
Frontend (`ChannelsTable.js`)
1. Initialize `loading` state to `true` so the spinner is visible while the first data request is in-flight.
2. Set `<Table>` prop `loading={loading || searching}` — the spinner now appears for both the initial load and any subsequent search requests.

Result
Users immediately see a loading indicator on page entry and whenever a search is running, improving perceived responsiveness.
2025-06-24 05:20:54 +08:00
t0ng7u
aeb393e391 🚀 feat: Align search API with channel listing & fix sorting toggle
1. Backend
   • `controller/channel.go`
     – Added pagination (`p`, `page_size`) support to `SearchChannels`.
     – Added independent `type` filter (keeps `type_counts` unaffected).
     – Returned `total`, `type_counts` to match `/api/channel/` response.

2. Frontend
   • `ChannelsTable.js`
     – `loadChannels` / `searchChannels` now pass `p`, `page_size`, `id_sort`, `type`, `status` correctly.
     – Pagination, page-size selector and type tabs work for both normal list and search mode.
     – Switch for “ID sort” calls proper API and keeps UI state in sync.
     – Removed unnecessary `normalize` helper; `getFormValues` back to concise form.

Result
• Search mode and normal listing now share identical pagination and filtering behavior.
• Type tabs show correct counts even after searching.
• “ID Sort” toggle no longer inverses actual behaviour.
2025-06-24 05:13:47 +08:00
t0ng7u
db1b11deaf fix(channels-table): preserve group filter when switching type or status tabs
Refactors `ChannelsTable.js` to ensure that the selected group filter is **never lost** when:

1. Cycling between channel-type tabs.
2. Changing the status dropdown (all / enabled / disabled).

Key points:

• `loadChannels` now detects active search filters (keyword / group / model) and transparently delegates to `searchChannels`, guaranteeing all parameters are sent in every request.
• `searchChannels` accepts optional `typeKey` and `statusF` arguments, enabling reuse without code duplication.
• Loading state handling is unified; no extra renders or side effects were introduced, keeping UI performance intact.
• Duplicate logic removed and responsibilities clearly separated for easier future maintenance.
2025-06-24 04:16:40 +08:00
t0ng7u
5a5e8ce652 Revert "🐛 fix: preserve group filter when switching channel type/status"
This reverts commit a8ba2eba33.
2025-06-24 01:51:26 +08:00
t0ng7u
6c31151430 💄 i18n: shorten channel search placeholder and update i18n
Replaced the verbose placeholder “Search channel ID, name, key and API address ...”
with a concise version “Channel ID, name, key, API address” in
`ChannelsTable.js` and synchronized the corresponding i18n entries.

This improves readability and keeps UI text consistent across languages.
2025-06-24 01:48:39 +08:00
t0ng7u
a8ba2eba33 🐛 fix: preserve group filter when switching channel type/status
Ensure that the selected "group" filter (and other form search values) persist across
type tab changes, status filter updates, pagination, and page-size changes.

Changes include:
• loadChannels: added `searchParams` argument and now appends keyword, group and model
  query strings to API calls.
• refresh / page handlers / type tabs / status Select: now pass current form values
  to loadChannels, keeping filters intact.
• searchChannels: maintains active type and status filters when issuing search requests.
• Form.Select (searchGroup): triggers loadChannels when only group filter is active,
  preventing parameter loss.
• Minor cleanup and comment adjustments.
2025-06-24 01:45:22 +08:00
t0ng7u
c974b1053c 🐛 fix(channel): remove duplicate model names in “Edit Channel” model dropdown (#1292)
• Unify the Select option structure as `{ key, label, value }`; add missing `key` to prevent duplicated rendering by Semi-UI.
• Trim and deduplicate the `models` array via `Set` inside `handleInputChange`, ensuring state always contains unique values.
• In the options-merging `useEffect`, use a `Map` keyed by `value` (after `trim`) to guarantee a unique `optionList` when combining backend data with currently selected models.
• Apply the same structure and de-duplication when:
  – Fetching models from `/api/channel/models`
  – Adding custom models (`addCustomModels`)
  – Fetching upstream model lists (`fetchUpstreamModelList`)
• Replace obsolete `text` field with `label` in custom option objects for consistency.

No backend changes are required; the fix is entirely front-end.

Closes #1292
2025-06-24 00:25:29 +08:00
t0ng7u
1ab75b8a92 🎨 feat(EditChannel): improve model selection UX, clipboard feedback & rounded styling (#1290)
* Added a dedicated effect to merge origin and selected models, ensuring selected items always remain in the dropdown list.
* Enhanced “Copy all models” button:
  * Shows info message when list is empty.
  * Displays success / error notification based on copy result.
* Unified UI look-and-feel by applying `!rounded-lg` class to inputs, selects, banners and buttons.
* i18n: added English translations for new prompts
  - "No models to copy"
  - "Model list copied to clipboard"
  - "Copy failed"
2025-06-24 00:02:22 +08:00
同語
75e3959474 🧬merge: Add a button to copy the selected model in the channel (#1290)
Merge pull request #1290 from JoeyLearnsToCode/feat-copy-models
2025-06-23 23:46:54 +08:00
t0ng7u
bc371778b6 🚀 feat: add enabled/disabled channel filtering & optimize type-based pagination (#1289)
WHAT’S NEW
• Backend
  – Introduced `parseStatusFilter` helper to normalize `status` query across handlers.
  – `GET /api/channel` & `GET /api/channel/search` now accept `status=enabled|disabled` to return only enabled or disabled channels.
  – Tag-mode branch respects both `statusFilter` and `typeFilter`; SQL paths trimmed to one query + one lightweight `GROUP BY` for `type_counts`.

• Frontend (`ChannelsTable.js`)
  – Added “Status Filter” `<Select>` (All / Enabled / Disabled) with localStorage persistence.
  – All data-loading and search requests now always append `type` (when not “all”) and `status` params, so filtering & pagination are handled entirely server-side.
  – Removed client-side post-filtering for type, preventing short pages and reducing CPU work.
  – Tabs’ type counts stay in sync via backend-provided `type_counts`.

IMPROVEMENTS
• Eliminated duplicated status-parsing logic; single source of truth eases future extension.
• Reduced redundant queries, improved consistency of counts in UI.
• Secured key leakage with `Omit("key")` unchanged; no perf regressions observed.

Closes #1289
2025-06-23 23:40:34 +08:00
skynono
cd2870aebc feat: add origin kling api 2025-06-23 22:36:23 +08:00
t0ng7u
7c72545217 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-23 17:35:50 +08:00
t0ng7u
2591ca3d60 🚀 chore(ui): Refactor UpstreamRatioSync with conflict-modal component, performance hooks & cleanup (#1286)
WHAT’S NEW
• Extracted reusable ConflictConfirmModal for clearer JSX hierarchy
• Added detailed conflict detection & confirmation flow before syncing options
• Refactored state-heavy callbacks (`selectValue`, `performSync`) with `useCallback` to avoid unnecessary renders
• Introduced build-time constants (later removed unused export) and unified helper utilities
• Ensured final ratios are rebuilt accurately before API `PUT`, fixing “value not updated” bug
• Enhanced UI hints: warning icon on conflict, multiline billing info, mobile-friendly modal size
• General code cleanup: removed dead variables, adopted early returns, improved comments

WHY
Improves maintainability, user clarity when billing-type collisions occur, and guarantees data consistency after synchronisation.
2025-06-23 17:35:39 +08:00
t0ng7u
c28190316f 🐛 fix(ratio-sync): reset pagination when filter/search changes
Add a `useEffect` hook in `UpstreamRatioSync.js` to automatically set
`currentPage` to `1` whenever `ratioTypeFilter` or `searchKeyword`
updates.
This prevents the table from appearing empty when users switch to the
“model_price” (fixed price) filter or perform a new search while on a
later page.

Additional changes:
- Import `useEffect` from React.

This enhancement delivers a smoother UX by ensuring the first page of
results is always shown after any filtering action.
2025-06-23 16:34:00 +08:00
JoeyLearnsToCode
ffc22b8dac Merge branch 'main' into feat-copy-models 2025-06-23 16:12:18 +08:00
t0ng7u
5367015a31 🎛️ feat(web): add “Conflict Rates” filter & highlight in Model Settings Visual Editor (#1286)
Introduce the ability to quickly locate models with conflicting billing configurations.

Key points
• Added `hasConflict` flag to detect models that define both a fixed price (`ModelPrice`) and any ratio (`ModelRatio` or `CompletionRatio`).
• Added “Show Only Conflict Rates” `Checkbox` to toolbar; filtering logic now supports keyword + conflict filtering.
• Display a red `Tag` beside the model name when a conflict is detected for immediate visual feedback.
• Kept `hasConflict` state in sync during add, update and delete operations.
• Imported `Checkbox` and `Tag` from **@douyinfe/semi-ui**.
• Minor UI tweaks (circle tag style, margin) for consistency.

This enhancement helps administrators swiftly identify and resolve incompatible pricing rules, addressing the need discussed in issue #1286.
2025-06-23 15:55:10 +08:00
CaIon
75c71c397e 🔧 chore: update STREAMING_TIMEOUT default value to 120 seconds in configuration 2025-06-22 18:47:40 +08:00
JoeyLearnsToCode
69420f713f feat: 渠道编辑页增加复制所有模型功能 2025-05-19 19:33:29 +08:00
152 changed files with 5413 additions and 4721 deletions

View File

@@ -7,6 +7,8 @@
# 调试相关配置
# 启用pprof
# ENABLE_PPROF=true
# 启用调试模式
# DEBUG=true
# 数据库相关配置
# 数据库连接字符串
@@ -41,6 +43,14 @@
# 更新任务启用
# UPDATE_TASK=true
# 对话超时设置
# 所有请求超时时间单位秒默认为0表示不限制
# RELAY_TIMEOUT=0
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=120
# Gemini 识别图片 最大图片数量
# GEMINI_VISION_MAX_IMAGE_NUM=16
# 会话密钥
# SESSION_SECRET=random_string
@@ -58,8 +68,6 @@
# GET_MEDIA_TOKEN_NOT_STREAM=true
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true
# 设置流式一次回复的超时时间
# STREAMING_TIMEOUT=90
# 节点类型

View File

@@ -0,0 +1,19 @@
### PR 类型
- [ ] Bug 修复
- [ ] 新功能
- [ ] 文档更新
- [ ] 其他
### PR 是否包含破坏性更新?
- [ ]
- [ ]
### PR 描述
**请在下方详细描述您的 PR包括目的、实现细节等。**
### **重要提示**
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**

View File

@@ -26,6 +26,7 @@ jobs:
- name: Build Frontend
env:
CI: ""
NODE_OPTIONS: "--max-old-space-size=4096"
run: |
cd web
bun install

View File

@@ -0,0 +1,21 @@
name: Check PR Branching Strategy
on:
pull_request:
types: [opened, synchronize, reopened, edited]
jobs:
check-branching-strategy:
runs-on: ubuntu-latest
steps:
- name: Enforce branching strategy
run: |
if [[ "${{ github.base_ref }}" == "main" ]]; then
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
exit 1
fi
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
exit 1
fi
echo "Branching strategy check passed."

View File

@@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`

View File

@@ -27,9 +27,6 @@
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
</a>
<a href="https://coderabbit.ai">
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
</a>
</p>
</div>
@@ -103,7 +100,7 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables)
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
- `DIFY_DEBUG`Dify渠道是否输出工作流和节点信息默认 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数默认 `true`
- `GET_MEDIA_TOKEN`是否统计图片token默认 `true`

71
common/api_type.go Normal file
View File

@@ -0,0 +1,71 @@
package common
import "one-api/constant"
func ChannelType2APIType(channelType int) (int, bool) {
apiType := -1
switch channelType {
case constant.ChannelTypeOpenAI:
apiType = constant.APITypeOpenAI
case constant.ChannelTypeAnthropic:
apiType = constant.APITypeAnthropic
case constant.ChannelTypeBaidu:
apiType = constant.APITypeBaidu
case constant.ChannelTypePaLM:
apiType = constant.APITypePaLM
case constant.ChannelTypeZhipu:
apiType = constant.APITypeZhipu
case constant.ChannelTypeAli:
apiType = constant.APITypeAli
case constant.ChannelTypeXunfei:
apiType = constant.APITypeXunfei
case constant.ChannelTypeAIProxyLibrary:
apiType = constant.APITypeAIProxyLibrary
case constant.ChannelTypeTencent:
apiType = constant.APITypeTencent
case constant.ChannelTypeGemini:
apiType = constant.APITypeGemini
case constant.ChannelTypeZhipu_v4:
apiType = constant.APITypeZhipuV4
case constant.ChannelTypeOllama:
apiType = constant.APITypeOllama
case constant.ChannelTypePerplexity:
apiType = constant.APITypePerplexity
case constant.ChannelTypeAws:
apiType = constant.APITypeAws
case constant.ChannelTypeCohere:
apiType = constant.APITypeCohere
case constant.ChannelTypeDify:
apiType = constant.APITypeDify
case constant.ChannelTypeJina:
apiType = constant.APITypeJina
case constant.ChannelCloudflare:
apiType = constant.APITypeCloudflare
case constant.ChannelTypeSiliconFlow:
apiType = constant.APITypeSiliconFlow
case constant.ChannelTypeVertexAi:
apiType = constant.APITypeVertexAi
case constant.ChannelTypeMistral:
apiType = constant.APITypeMistral
case constant.ChannelTypeDeepSeek:
apiType = constant.APITypeDeepSeek
case constant.ChannelTypeMokaAI:
apiType = constant.APITypeMokaAI
case constant.ChannelTypeVolcEngine:
apiType = constant.APITypeVolcEngine
case constant.ChannelTypeBaiduV2:
apiType = constant.APITypeBaiduV2
case constant.ChannelTypeOpenRouter:
apiType = constant.APITypeOpenRouter
case constant.ChannelTypeXinference:
apiType = constant.APITypeXinference
case constant.ChannelTypeXai:
apiType = constant.APITypeXai
case constant.ChannelTypeCoze:
apiType = constant.APITypeCoze
}
if apiType == -1 {
return constant.APITypeOpenAI, false
}
return apiType, true
}

View File

@@ -193,109 +193,3 @@ const (
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://api.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
}

41
common/endpoint_type.go Normal file
View File

@@ -0,0 +1,41 @@
package common
import "one-api/constant"
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
var endpointTypes []constant.EndpointType
switch channelType {
case constant.ChannelTypeJina:
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
//case constant.ChannelTypeSunoAPI:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
//case constant.ChannelTypeKling:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
//case constant.ChannelTypeJimeng:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
case constant.ChannelTypeAws:
fallthrough
case constant.ChannelTypeAnthropic:
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
case constant.ChannelTypeVertexAi:
fallthrough
case constant.ChannelTypeGemini:
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
default:
if IsOpenAIResponseOnlyModel(modelName) {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
} else {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
}
}
if IsImageGenerationModel(modelName) {
// add to first
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
}
return endpointTypes
}

View File

@@ -2,10 +2,11 @@ package common
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"one-api/constant"
"strings"
"time"
)
const KeyRequestBody = "key_request_body"
@@ -31,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
err = UnmarshalJson(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
@@ -43,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
c.Set(string(key), value)
}
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
return c.Get(string(key))
}
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
return c.GetString(string(key))
}
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
return c.GetInt(string(key))
}
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
return c.GetBool(string(key))
}
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
return c.GetStringSlice(string(key))
}
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
return c.GetStringMap(string(key))
}
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
return c.GetTime(string(key))
}

57
common/http.go Normal file
View File

@@ -0,0 +1,57 @@
package common
import (
"bytes"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
)
func CloseResponseBodyGracefully(httpResponse *http.Response) {
if httpResponse == nil || httpResponse.Body == nil {
return
}
err := httpResponse.Body.Close()
if err != nil {
SysError("failed to close response body: " + err.Error())
}
}
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
if c.Writer == nil {
return
}
body := io.NopCloser(bytes.NewBuffer(data))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
if src != nil {
for k, v := range src.Header {
// avoid setting Content-Length
if k == "Content-Length" {
continue
}
c.Writer.Header().Set(k, v[0])
}
}
// set Content-Length header manually BEFORE calling WriteHeader
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
// Write header with status code (this sends the headers)
if src != nil {
c.Writer.WriteHeader(src.StatusCode)
} else {
c.Writer.WriteHeader(http.StatusOK)
}
_, err := io.Copy(c.Writer, body)
if err != nil {
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
}
}

View File

@@ -4,6 +4,7 @@ import (
"flag"
"fmt"
"log"
"one-api/constant"
"os"
"path/filepath"
"strconv"
@@ -24,7 +25,7 @@ func printHelp() {
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
}
func LoadEnv() {
func InitEnv() {
flag.Parse()
if *PrintVersion {
@@ -95,4 +96,25 @@ func LoadEnv() {
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
initConstantEnv()
}
func initConstantEnv() {
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
}

View File

@@ -5,12 +5,16 @@ import (
"encoding/json"
)
func DecodeJson(data []byte, v any) error {
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
func UnmarshalJson(data []byte, v any) error {
return json.Unmarshal(data, v)
}
func DecodeJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v)
func UnmarshalJsonStr(data string, v any) error {
return json.Unmarshal(StringToByteSlice(data), v)
}
func DecodeJson(reader *bytes.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
}
func EncodeJson(v any) ([]byte, error) {

42
common/model.go Normal file
View File

@@ -0,0 +1,42 @@
package common
import "strings"
var (
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
OpenAIResponseOnlyModels = []string{
"o3-pro",
"o3-deep-research",
"o4-mini-deep-research",
}
ImageGenerationModels = []string{
"dall-e-3",
"dall-e-2",
"gpt-image-1",
"prefix:imagen-",
"flux-",
"flux.1-",
}
)
func IsOpenAIResponseOnlyModel(modelName string) bool {
for _, m := range OpenAIResponseOnlyModels {
if strings.Contains(modelName, m) {
return true
}
}
return false
}
func IsImageGenerationModel(modelName string) bool {
modelName = strings.ToLower(modelName)
for _, m := range ImageGenerationModels {
if strings.Contains(modelName, m) {
return true
}
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
return true
}
}
return false
}

62
common/page_info.go Normal file
View File

@@ -0,0 +1,62 @@
package common
import (
"github.com/gin-gonic/gin"
"strconv"
)
type PageInfo struct {
Page int `json:"page"` // page num 页码
PageSize int `json:"page_size"` // page size 页大小
StartTimestamp int64 `json:"start_timestamp"` // 秒级
EndTimestamp int64 `json:"end_timestamp"` // 秒级
Total int `json:"total"` // 总条数,后设置
Items any `json:"items"` // 数据,后设置
}
func (p *PageInfo) GetStartIdx() int {
return (p.Page - 1) * p.PageSize
}
func (p *PageInfo) GetEndIdx() int {
return p.Page * p.PageSize
}
func (p *PageInfo) GetPageSize() int {
return p.PageSize
}
func (p *PageInfo) GetPage() int {
return p.Page
}
func (p *PageInfo) SetTotal(total int) {
p.Total = total
}
func (p *PageInfo) SetItems(items any) {
p.Items = items
}
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
pageInfo := &PageInfo{}
err := c.BindQuery(pageInfo)
if err != nil {
return nil, err
}
if pageInfo.Page < 1 {
// 兼容
page, _ := strconv.Atoi(c.Query("p"))
if page != 0 {
pageInfo.Page = page
} else {
pageInfo.Page = 1
}
}
if pageInfo.PageSize == 0 {
pageInfo.PageSize = ItemsPerPage
}
return pageInfo, nil
}

View File

@@ -16,6 +16,10 @@ import (
var RDB *redis.Client
var RedisEnabled = true
func RedisKeyCacheSeconds() int {
return SyncFrequency
}
// InitRedisClient This function is called after init()
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {

26
constant/README.md Normal file
View File

@@ -0,0 +1,26 @@
# constant 包 (`/constant`)
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
## 当前文件
| 文件 | 说明 |
|----------------------|---------------------------------------------------------------------|
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
| `channel_setting.go` | Channel 级别的设置键,如 `proxy``force_format` 等。 |
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量请求时间、Token/Channel/User 相关信息等)。 |
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
## 使用约定
1. `constant` 包**只能被其他包引用**import**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。

34
constant/api_type.go Normal file
View File

@@ -0,0 +1,34 @@
package constant
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeZhipuV4
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeCoze
APITypeDummy // this one is only for count, do not add any channel after this
)

View File

@@ -1,12 +1,5 @@
package constant
import "one-api/common"
// 使用函数来避免初始化顺序带来的赋值问题
func RedisKeyCacheSeconds() int {
return common.SyncFrequency
}
// Cache keys
const (
UserGroupKeyFmt = "user_group:%d"

109
constant/channel.go Normal file
View File

@@ -0,0 +1,109 @@
package constant
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeJimeng = 51
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://api.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51
}

View File

@@ -1,11 +1,35 @@
package constant
type ContextKey string
const (
ContextKeyRequestStartTime = "request_start_time"
ContextKeyUserSetting = "user_setting"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status"
ContextKeyUserEmail = "user_email"
ContextKeyUserGroup = "user_group"
ContextKeyUsingGroup = "group"
ContextKeyOriginalModel ContextKey = "original_model"
ContextKeyRequestStartTime ContextKey = "request_start_time"
/* token related keys */
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
ContextKeyTokenKey ContextKey = "token_key"
ContextKeyTokenId ContextKey = "token_id"
ContextKeyTokenGroup ContextKey = "token_group"
ContextKeyTokenAllowIps ContextKey = "allow_ips"
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
/* channel related keys */
ContextKeyBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyParamOverride ContextKey = "param_override"
/* user related keys */
ContextKeyUserId ContextKey = "id"
ContextKeyUserSetting ContextKey = "user_setting"
ContextKeyUserQuota ContextKey = "user_quota"
ContextKeyUserStatus ContextKey = "user_status"
ContextKeyUserEmail ContextKey = "user_email"
ContextKeyUserGroup ContextKey = "user_group"
ContextKeyUsingGroup ContextKey = "group"
ContextKeyUserName ContextKey = "username"
)

16
constant/endpoint_type.go Normal file
View File

@@ -0,0 +1,16 @@
package constant
type EndpointType string
const (
EndpointTypeOpenAI EndpointType = "openai"
EndpointTypeOpenAIResponse EndpointType = "openai-response"
EndpointTypeAnthropic EndpointType = "anthropic"
EndpointTypeGemini EndpointType = "gemini"
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeImageGeneration EndpointType = "image-generation"
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
//EndpointTypeSuno EndpointType = "suno-proxy"
//EndpointTypeKling EndpointType = "kling"
//EndpointTypeJimeng EndpointType = "jimeng"
)

View File

@@ -1,9 +1,5 @@
package constant
import (
"one-api/common"
)
var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
@@ -17,39 +13,3 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
func InitEnv() {
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
//}
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
// parts := strings.Split(pair, ":")
// if len(parts) == 2 {
// GeminiModelMap[parts[0]] = parts[1]
// } else {
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
// }
//}
}

View File

@@ -6,11 +6,15 @@ const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
TaskPlatformKling TaskPlatform = "kling"
TaskPlatformJimeng TaskPlatform = "jimeng"
)
const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
TaskActionGenerate = "generate"
TaskActionTextGenerate = "textGenerate"
)
var SunoModel2Action = map[string]string{

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/service"
"one-api/setting"
@@ -341,34 +342,34 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
case common.ChannelTypeOpenAI:
case constant.ChannelTypeOpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
case common.ChannelTypeAzure:
case constant.ChannelTypeAzure:
return 0, errors.New("尚未实现")
case common.ChannelTypeCustom:
case constant.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy:
case constant.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT:
case constant.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D:
case constant.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel)
case common.ChannelTypeSiliconFlow:
case constant.ChannelTypeSiliconFlow:
return updateChannelSiliconFlowBalance(channel)
case common.ChannelTypeDeepSeek:
case constant.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel)
case common.ChannelTypeOpenRouter:
case constant.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
case common.ChannelTypeMoonshot:
case constant.ChannelTypeMoonshot:
return updateChannelMoonshotBalance(channel)
default:
return 0, errors.New("尚未实现")

View File

@@ -11,12 +11,12 @@ import (
"net/http/httptest"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strconv"
@@ -31,18 +31,21 @@ import (
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney {
if channel.Type == constant.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported!!!"), nil
if channel.Type == constant.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeSunoAPI {
if channel.Type == constant.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeKling {
if channel.Type == constant.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
}
if channel.Type == constant.ChannelTypeJimeng {
return errors.New("jimeng channel test is not supported"), nil
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -53,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
strings.Contains(testModel, "embed") ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
@@ -99,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
}
testModel = info.UpstreamModelName
apiType, _ := constant.ChannelType2APIType(channel.Type)
apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
@@ -199,7 +202,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300
testRequest.MaxTokens = 3000
} else {
testRequest.MaxTokens = 10
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"strconv"
"strings"
@@ -40,6 +41,17 @@ type OpenAIModelsResponse struct {
Success bool `json:"success"`
}
func parseStatusFilter(statusParam string) int {
switch strings.ToLower(statusParam) {
case "enabled", "1":
return common.ChannelStatusEnabled
case "disabled", "0":
return 0
default:
return -1
}
}
func GetAllChannels(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
@@ -52,6 +64,9 @@ func GetAllChannels(c *gin.Context) {
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
statusParam := c.Query("status")
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
statusFilter := parseStatusFilter(statusParam)
// type filter
typeStr := c.Query("type")
typeFilter := -1
@@ -64,42 +79,75 @@ func GetAllChannels(c *gin.Context) {
var total int64
if enableTagMode {
// tag 分页:先分页 tag再取各 tag 下 channels
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
for _, tag := range tags {
if tag != nil && *tag != "" {
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
if err == nil {
channelData = append(channelData, tagChannel...)
}
if tag == nil || *tag == "" {
continue
}
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
if err != nil {
continue
}
filtered := make([]*model.Channel, 0)
for _, ch := range tagChannels {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
if typeFilter >= 0 && ch.Type != typeFilter {
continue
}
filtered = append(filtered, ch)
}
channelData = append(channelData, filtered...)
}
// 计算 tag 总数用于分页
total, _ = model.CountAllTags()
} else if typeFilter >= 0 {
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
channelData = channels
total, _ = model.CountChannelsByType(typeFilter)
} else {
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
baseQuery := model.DB.Model(&model.Channel{})
if typeFilter >= 0 {
baseQuery = baseQuery.Where("type = ?", typeFilter)
}
if statusFilter == common.ChannelStatusEnabled {
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
}
baseQuery.Count(&total)
order := "priority desc"
if idSort {
order = "id desc"
}
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
channelData = channels
total, _ = model.CountAllChannels()
}
// calculate type counts
typeCounts, _ := model.CountChannelsGroupByType()
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
} else if statusFilter == 0 {
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
}
var results []struct {
Type int64
Count int64
}
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
typeCounts := make(map[int64]int64)
for _, r := range results {
typeCounts[r.Type] = r.Count
}
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -134,15 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
return
}
baseURL := common.ChannelBaseURLs[channel.Type]
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
switch channel.Type {
case common.ChannelTypeGemini:
case constant.ChannelTypeGemini:
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
case common.ChannelTypeAli:
case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
@@ -166,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
var ids []string
for _, model := range result.Data {
id := model.ID
if channel.Type == common.ChannelTypeGemini {
if channel.Type == constant.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
@@ -199,6 +247,8 @@ func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
statusParam := c.Query("status")
statusFilter := parseStatusFilter(statusParam)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
channelData := make([]*model.Channel, 0)
@@ -231,17 +281,71 @@ func SearchChannels(c *gin.Context) {
channelData = channels
}
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
filtered := make([]*model.Channel, 0, len(channelData))
for _, ch := range channelData {
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
continue
}
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
continue
}
filtered = append(filtered, ch)
}
channelData = filtered
}
// calculate type counts for search results
typeCounts := make(map[int64]int64)
for _, channel := range channelData {
typeCounts[int64(channel.Type)]++
}
typeParam := c.Query("type")
typeFilter := -1
if typeParam != "" {
if tp, err := strconv.Atoi(typeParam); err == nil {
typeFilter = tp
}
}
if typeFilter >= 0 {
filtered := make([]*model.Channel, 0, len(channelData))
for _, ch := range channelData {
if ch.Type == typeFilter {
filtered = append(filtered, ch)
}
}
channelData = filtered
}
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page < 1 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
total := len(channelData)
startIdx := (page - 1) * pageSize
if startIdx > total {
startIdx = total
}
endIdx := startIdx + pageSize
if endIdx > total {
endIdx = total
}
pagedData := channelData[startIdx:endIdx]
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": channelData,
"items": pagedData,
"total": total,
"type_counts": typeCounts,
},
})
@@ -285,7 +389,7 @@ func AddChannel(c *gin.Context) {
}
channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n")
if channel.Type == common.ChannelTypeVertexAi {
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -510,7 +614,7 @@ func UpdateChannel(c *gin.Context) {
})
return
}
if channel.Type == common.ChannelTypeVertexAi {
if channel.Type == constant.ChannelTypeVertexAi {
if channel.Other == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -565,7 +669,7 @@ func FetchModels(c *gin.Context) {
baseURL := req.BaseURL
if baseURL == "" {
baseURL = common.ChannelBaseURLs[req.Type]
baseURL = constant.ChannelBaseURLs[req.Type]
}
client := &http.Client{}

View File

@@ -2,6 +2,7 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"net/http"
"one-api/common"
@@ -14,10 +15,7 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/setting"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -26,30 +24,10 @@ var openAIModels []dto.OpenAIModels
var openAIModelsMap map[string]dto.OpenAIModels
var channelId2Models map[int][]string
func getPermission() []dto.OpenAIModelPermission {
var permission []dto.OpenAIModelPermission
permission = append(permission, dto.OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: false,
AllowView: true,
AllowFineTuning: false,
Organization: "*",
Group: nil,
IsBlocking: false,
})
return permission
}
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
permission := getPermission()
for i := 0; i < relayconstant.APITypeDummy; i++ {
if i == relayconstant.APITypeAIProxyLibrary {
for i := 0; i < constant.APITypeDummy; i++ {
if i == constant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
@@ -57,69 +35,51 @@ func init() {
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: ai360.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: ai360.ChannelName,
})
}
for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: moonshot.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: moonshot.ChannelName,
})
}
for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: minimax.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: minimax.ChannelName,
})
}
for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "midjourney",
Permission: permission,
Root: modelName,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "midjourney",
})
}
openAIModelsMap = make(map[string]dto.OpenAIModels)
@@ -127,9 +87,9 @@ func init() {
openAIModelsMap[aiModel.Id] = aiModel
}
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
apiType, success := relayconstant.ChannelType2APIType(i)
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
for i := 1; i <= constant.ChannelTypeDummy; i++ {
apiType, success := common.ChannelType2APIType(i)
if !success || apiType == constant.APITypeAIProxyLibrary {
continue
}
meta := &relaycommon.RelayInfo{ChannelType: i}
@@ -144,11 +104,10 @@ func init() {
func ListModels(c *gin.Context) {
userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission()
modelLimitEnable := c.GetBool("token_model_limit_enabled")
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
@@ -156,17 +115,16 @@ func ListModels(c *gin.Context) {
tokenModelLimit = map[string]bool{}
}
for allowModel, _ := range tokenModelLimit {
if _, ok := openAIModelsMap[allowModel]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: allowModel,
Parent: nil,
Id: allowModel,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
})
}
}
@@ -181,14 +139,14 @@ func ListModels(c *gin.Context) {
return
}
group := userGroup
tokenGroup := c.GetString("token_group")
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
if tokenGroup != "" {
group = tokenGroup
}
var models []string
if tokenGroup == "auto" {
for _, autoGroup := range setting.AutoGroups {
groupModels := model.GetGroupModels(autoGroup)
groupModels := model.GetGroupEnabledModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
models = append(models, g)
@@ -196,20 +154,19 @@ func ListModels(c *gin.Context) {
}
}
} else {
models = model.GetGroupModels(group)
models = model.GetGroupEnabledModels(group)
}
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
for _, modelName := range models {
if oaiModel, ok := openAIModelsMap[modelName]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Permission: permission,
Root: s,
Parent: nil,
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
})
}
}

View File

@@ -65,7 +65,7 @@ func Playground(c *gin.Context) {
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")

View File

@@ -8,12 +8,12 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/constant"
constant2 "one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
}
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
if channelType == constant.ChannelTypeAnthropic {
return true
}
return false

View File

@@ -74,8 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
case constant.TaskPlatformKling:
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
default:
common.SysLog("未知平台")
}

View File

@@ -2,27 +2,26 @@ package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
"time"
)
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
@@ -39,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, ta
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
adaptor := relay.GetTaskAdaptor(platform)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
@@ -52,74 +51,68 @@ func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, ta
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := common.ChannelBaseURLs[channel.Type]
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
})
if err != nil {
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
}
var responseItem map[string]interface{}
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
}
code, _ := responseItem["code"].(float64)
if code != 0 {
return fmt.Errorf("video task fetch failed for task %s", taskId)
}
data, ok := responseItem["data"].(map[string]interface{})
if !ok {
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
return fmt.Errorf("video task data format error for task %s", taskId)
}
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
if status, ok := data["task_status"].(string); ok {
switch status {
case "submitted", "queued":
task.Status = model.TaskStatusSubmitted
case "processing":
task.Status = model.TaskStatusInProgress
case "succeed":
task.Status = model.TaskStatusSuccess
task.Progress = "100%"
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
task.FailReason = url
} else {
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
}
case "failed":
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if reason, ok := data["fail_reason"].(string); ok {
task.FailReason = reason
}
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
"action": task.Action,
})
if err != nil {
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
}
//if resp.StatusCode != http.StatusOK {
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
//}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
// If task failed, refund quota
if task.Status == model.TaskStatusFailure {
taskResult, err := adaptor.ParseTaskResult(responseBody)
if err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
}
//if taskResult.Code != 0 {
// return fmt.Errorf("video task fetch failed for task %s", taskId)
//}
now := time.Now().Unix()
if taskResult.Status == "" {
return fmt.Errorf("task %s status is empty", taskId)
}
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
task.Progress = "10%"
case model.TaskStatusQueued:
task.Progress = "20%"
case model.TaskStatusInProgress:
task.Progress = "30%"
if task.StartTime == 0 {
task.StartTime = now
}
case model.TaskStatusSuccess:
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Url
case model.TaskStatusFailure:
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Reason
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
@@ -129,6 +122,11 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
}
if taskResult.Progress != "" {
task.Progress = taskResult.Progress
}
task.Data = responseBody

View File

@@ -246,15 +246,15 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
pageInfo, err := common.GetPageQuery(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "parse page query failed",
})
return
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
users, total, err := model.GetAllUsers(pageInfo)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -262,15 +262,13 @@ func GetAllUsers(c *gin.Context) {
})
return
}
pageInfo.SetTotal(int(total))
pageInfo.SetItems(users)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
"data": pageInfo,
})
return
}
@@ -489,7 +487,7 @@ func GetUserModels(c *gin.Context) {
groups := setting.GetUserUsableGroups(user.Group)
var models []string
for group := range groups {
for _, g := range model.GetGroupModels(group) {
for _, g := range model.GetGroupEnabledModels(group) {
if !common.StringsContains(models, g) {
models = append(models, g)
}

View File

@@ -16,7 +16,7 @@ services:
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间单位秒默认120秒如果出现空补全可以尝试改为更大值
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

View File

@@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct {
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any)
data, _ := common.EncodeJson(r)
_ = common.DecodeJson(data, &result)
_ = common.UnmarshalJson(data, &result)
return result
}
@@ -646,4 +646,6 @@ type ResponsesToolsCall struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
Function json.RawMessage `json:"function,omitempty"`
Container json.RawMessage `json:"container,omitempty"`
}

View File

@@ -1,26 +1,11 @@
package dto
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
AllowCreateEngine bool `json:"allow_create_engine"`
AllowSampling bool `json:"allow_sampling"`
AllowLogprobs bool `json:"allow_logprobs"`
AllowSearchIndices bool `json:"allow_search_indices"`
AllowView bool `json:"allow_view"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
}
import "one-api/constant"
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"`
Parent *string `json:"parent"`
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
OwnedBy string `json:"owned_by"`
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
}

View File

@@ -4,7 +4,7 @@ type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
TopN int `json:"top_n,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`

85
main.go
View File

@@ -32,12 +32,12 @@ var buildFS embed.FS
var indexPage []byte
func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysLog("Support for .env file is disabled: " + err.Error())
}
common.LoadEnv()
err := InitResources()
if err != nil {
common.FatalLog("failed to initialize resources: " + err.Error())
return
}
common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
@@ -47,19 +47,7 @@ func main() {
if common.DebugEnabled {
common.SysLog("running in debug mode")
}
// Initialize SQL Database
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
model.CheckSetup()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {
@@ -67,21 +55,6 @@ func main() {
}
}()
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize model settings
ratio_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
service.InitTokenEncoders()
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
@@ -186,3 +159,51 @@ func main() {
common.FatalLog("failed to start HTTP server: " + err.Error())
}
}
func InitResources() error {
// Initialize resources here if needed
// This is a placeholder function for future resource initialization
err := godotenv.Load(".env")
if err != nil {
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
}
// 加载环境变量
common.InitEnv()
// Initialize model settings
ratio_setting.InitRatioSettings()
service.InitHttpClient()
service.InitTokenEncoders()
// Initialize SQL Database
err = model.InitDB()
if err != nil {
common.FatalLog("failed to initialize database: " + err.Error())
return err
}
model.CheckSetup()
// Initialize options, should after model.InitDB()
model.InitOptionMap()
// 初始化模型
model.GetPricing()
// Initialize SQL Database
err = model.InitLogDB()
if err != nil {
return err
}
// Initialize Redis
err = common.InitRedisClient()
if err != nil {
return err
}
return nil
}

View File

@@ -184,7 +184,7 @@ func TokenAuth() func(c *gin.Context) {
}
}
// gemini api 从query中获取key
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
skKey := c.Query("key")
if skKey != "" {
c.Request.Header.Set("Authorization", "Bearer "+skKey)

View File

@@ -25,7 +25,7 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := c.GetStringMap("allow_ips")
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
@@ -34,14 +34,14 @@ func Distribute() func(c *gin.Context) {
}
}
var channel *model.Channel
channelId, ok := c.Get("specific_channel_id")
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup := c.GetString(constant.ContextKeyUserGroup)
tokenGroup := c.GetString("token_group")
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
@@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) {
}
userGroup = tokenGroup
}
c.Set(constant.ContextKeyUsingGroup, userGroup)
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -76,9 +76,9 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
@@ -121,7 +121,7 @@ func Distribute() func(c *gin.Context) {
}
}
}
c.Set(constant.ContextKeyRequestStartTime, time.Now())
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
@@ -171,15 +171,25 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeKlingFetchByID {
shouldSelectChannel = false
err = common.UnmarshalBodyReusable(c, &modelRequest)
var platform string
var relayMode int
if strings.HasPrefix(modelRequest.Model, "jimeng") {
platform = string(constant.TaskPlatformJimeng)
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeJimengFetchByID {
shouldSelectChannel = false
}
} else {
err = common.UnmarshalBodyReusable(c, &modelRequest)
platform = string(constant.TaskPlatformKling)
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeKlingFetchByID {
shouldSelectChannel = false
}
}
c.Set("platform", string(constant.TaskPlatformKling))
c.Set("platform", platform)
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
@@ -251,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL())
// TODO: api_version统一
switch channel.Type {
case common.ChannelTypeAzure:
case constant.ChannelTypeAzure:
c.Set("api_version", channel.Other)
case common.ChannelTypeVertexAi:
case constant.ChannelTypeVertexAi:
c.Set("region", channel.Other)
case common.ChannelTypeXunfei:
case constant.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
case common.ChannelTypeGemini:
case constant.ChannelTypeGemini:
c.Set("api_version", channel.Other)
case common.ChannelTypeAli:
case constant.ChannelTypeAli:
c.Set("plugin", channel.Other)
case common.ChannelCloudflare:
case constant.ChannelCloudflare:
c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI:
case constant.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
case common.ChannelTypeCoze:
case constant.ChannelTypeCoze:
c.Set("bot_id", channel.Other)
}
}

View File

@@ -0,0 +1,47 @@
package middleware
import (
"bytes"
"encoding/json"
"io"
"one-api/common"
"one-api/constant"
"github.com/gin-gonic/gin"
)
func KlingRequestConvert() func(c *gin.Context) {
return func(c *gin.Context) {
var originalReq map[string]interface{}
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
c.Next()
return
}
model, _ := originalReq["model"].(string)
prompt, _ := originalReq["prompt"].(string)
unifiedReq := map[string]interface{}{
"model": model,
"prompt": prompt,
"metadata": originalReq,
}
jsonData, err := json.Marshal(unifiedReq)
if err != nil {
c.Next()
return
}
// Rewrite request body and path
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
c.Request.URL.Path = "/v1/video/generations"
if image := originalReq["image"]; image == "" {
c.Set("action", constant.TaskActionTextGenerate)
}
// We have to reset the request body for the next handlers
c.Set(common.KeyRequestBody, jsonData)
c.Next()
}
}

View File

@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 获取分组
group := c.GetString("token_group")
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if group == "" {
group = c.GetString(constant.ContextKeyUserGroup)
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
}
//获取分组的限流配置

View File

@@ -21,7 +21,22 @@ type Ability struct {
Tag *string `json:"tag" gorm:"index"`
}
func GetGroupModels(group string) []string {
type AbilityWithChannel struct {
Ability
ChannelType int `json:"channel_type"`
}
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
var abilities []AbilityWithChannel
err := DB.Table("abilities").
Select("abilities.*, channels.type as channel_type").
Joins("left join channels on abilities.channel_id = channels.id").
Where("abilities.enabled = ?", true).
Scan(&abilities).Error
return abilities, err
}
func GetGroupEnabledModels(group string) []string {
var models []string
// Find distinct models
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
} else {
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
}
}

View File

@@ -1,20 +1,24 @@
package model
import (
"fmt"
"one-api/common"
"one-api/constant"
"one-api/setting/ratio_setting"
"one-api/types"
"sync"
"time"
)
type Pricing struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups,omitempty"`
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
OwnerBy string `json:"owner_by"`
CompletionRatio float64 `json:"completion_ratio"`
EnableGroup []string `json:"enable_groups"`
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
}
var (
@@ -23,47 +27,89 @@ var (
updatePricingLock sync.Mutex
)
func GetPricing() []Pricing {
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
var (
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
modelSupportEndpointsLock = sync.RWMutex{}
)
func GetPricing() []Pricing {
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
updatePricing()
updatePricingLock.Lock()
defer updatePricingLock.Unlock()
// Double check after acquiring the lock
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
modelSupportEndpointsLock.Lock()
defer modelSupportEndpointsLock.Unlock()
updatePricing()
}
}
//if group != "" {
// userPricingMap := make([]Pricing, 0)
// models := GetGroupModels(group)
// for _, pricing := range pricingMap {
// if !common.StringsContains(models, pricing.ModelName) {
// pricing.Available = false
// }
// userPricingMap = append(userPricingMap, pricing)
// }
// return userPricingMap
//}
return pricingMap
}
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
if model == "" {
return make([]constant.EndpointType, 0)
}
modelSupportEndpointsLock.RLock()
defer modelSupportEndpointsLock.RUnlock()
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
return endpoints
}
return make([]constant.EndpointType, 0)
}
func updatePricing() {
//modelRatios := common.GetModelRatios()
enableAbilities := GetAllEnableAbilities()
modelGroupsMap := make(map[string][]string)
enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil {
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return
}
modelGroupsMap := make(map[string]*types.Set[string])
for _, ability := range enableAbilities {
groups := modelGroupsMap[ability.Model]
if groups == nil {
groups = make([]string, 0)
groups, ok := modelGroupsMap[ability.Model]
if !ok {
groups = types.NewSet[string]()
modelGroupsMap[ability.Model] = groups
}
if !common.StringsContains(groups, ability.Group) {
groups = append(groups, ability.Group)
groups.Add(ability.Group)
}
//这里使用切片而不是Set因为一个模型可能支持多个端点类型并且第一个端点是优先使用端点
modelSupportEndpointsStr := make(map[string][]string)
for _, ability := range enableAbilities {
endpoints, ok := modelSupportEndpointsStr[ability.Model]
if !ok {
endpoints = make([]string, 0)
modelSupportEndpointsStr[ability.Model] = endpoints
}
modelGroupsMap[ability.Model] = groups
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
for _, channelType := range channelTypes {
if !common.StringsContains(endpoints, string(channelType)) {
endpoints = append(endpoints, string(channelType))
}
}
modelSupportEndpointsStr[ability.Model] = endpoints
}
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
for model, endpoints := range modelSupportEndpointsStr {
supportedEndpoints := make([]constant.EndpointType, 0)
for _, endpointStr := range endpoints {
endpointType := constant.EndpointType(endpointStr)
supportedEndpoints = append(supportedEndpoints, endpointType)
}
modelSupportEndpointTypes[model] = supportedEndpoints
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
ModelName: model,
EnableGroup: groups,
ModelName: model,
EnableGroup: groups.Items(),
SupportedEndpointTypes: modelSupportEndpointTypes[model],
}
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {

View File

@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
if err != nil {
return err
}

View File

@@ -114,7 +114,7 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
// Start transaction
tx := DB.Begin()
if tx.Error != nil {
@@ -134,7 +134,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error)
}
// Get paginated users within same transaction
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err

View File

@@ -24,12 +24,12 @@ type UserBase struct {
}
func (user *UserBase) WriteContext(c *gin.Context) {
c.Set(constant.ContextKeyUserGroup, user.Group)
c.Set(constant.ContextKeyUserQuota, user.Quota)
c.Set(constant.ContextKeyUserStatus, user.Status)
c.Set(constant.ContextKeyUserEmail, user.Email)
c.Set("username", user.Username)
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
}
func (user *UserBase) GetSetting() map[string]interface{} {
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
)
}

View File

@@ -45,5 +45,5 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseResultUrl(resp map[string]any) (string, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}

View File

@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
@@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return embeddingRequestOpenAI2Ali(request), nil
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

View File

@@ -132,10 +132,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -35,10 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)

View File

@@ -39,34 +39,18 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var aliResponse AliEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
var fullTextResponse dto.OpenAIEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
common.CloseResponseBodyGracefully(resp)
model := c.GetString("model")
if model == "" {
model = "text-embedding-v4"
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
@@ -186,10 +170,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
return nil, &usage
}
@@ -199,10 +180,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -166,10 +166,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
return nil, &usage
}
@@ -179,10 +176,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -215,10 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -280,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := service.GetImpatientHttpClient().Do(req)
res, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}

View File

@@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
if textRequest.Reasoning != nil {
var reasoning openrouter.RequestReasoning
if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil {
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
return nil, err
}
@@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJsonStr(data, &claudeResponse)
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
@@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJson(data, &claudeResponse)
err := common.UnmarshalJson(data, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
}
@@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
case relaycommon.RelayFormatClaude:
responseData = data
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(responseData)
common.IOCopyBytesGracefully(c, nil, responseData)
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
@@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}

View File

@@ -81,10 +81,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
common.LogError(c, "close_response_body_failed: "+err.Error())
}
common.CloseResponseBodyGracefully(resp)
return nil, usage
}
@@ -94,10 +91,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
@@ -127,10 +121,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -173,10 +173,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
@@ -217,10 +214,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {

View File

@@ -48,10 +48,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
var response dto.TextResponse
var cozeResponse CozeChatDetailResponse

View File

@@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request
client := service.GetImpatientHttpClient()
client := service.GetHttpClient()
resp, err := client.Do(req)
if err != nil {
common.SysError("failed to send request: " + err.Error())
@@ -257,10 +257,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,7 +1,6 @@
package gemini
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
@@ -15,15 +14,13 @@ import (
)
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
defer common.CloseResponseBodyGracefully(resp)
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
if common.DebugEnabled {
println(string(responseBody))
@@ -31,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
// 解析为 Gemini 原生响应格式
var geminiResponse GeminiChatResponse
err = common.DecodeJson(responseBody, &geminiResponse)
err = common.UnmarshalJson(responseBody, &geminiResponse)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
@@ -54,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
}
// 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := json.Marshal(geminiResponse)
jsonResponse, err := common.EncodeJson(geminiResponse)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
// 设置响应头并写入响应
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil
}
@@ -80,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse)
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false

View File

@@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse)
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
@@ -866,15 +866,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse GeminiChatResponse
err = common.DecodeJson(responseBody, &geminiResponse)
err = common.UnmarshalJson(responseBody, &geminiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -920,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
}
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiEmbeddingResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
@@ -956,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
}
openAIResponse.Usage = *usage.(*dto.Usage)
jsonResponse, jsonErr := json.Marshal(openAIResponse)
jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}

View File

@@ -3,6 +3,7 @@ package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
"jina-reranker-m0",
}
var ChannelName = "jina"

View File

@@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
)
@@ -26,7 +27,7 @@ func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.Embeddin
}
return &dto.EmbeddingRequest{
Input: input,
Model: request.Model,
Model: request.Model,
}
}
@@ -53,10 +54,7 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -80,4 +78,3 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,12 +1,12 @@
package ollama
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
@@ -88,10 +88,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -120,31 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
// Copy headers
for k, v := range resp.Header {
// 删除任何现有的相同头部,以防止重复添加头部
c.Writer.Header().Del(k)
for _, vv := range v {
c.Writer.Header().Add(k, vv)
}
}
// reset content length
c.Writer.Header().Del("Content-Length")
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.IOCopyBytesGracefully(c, resp, doResponseBody)
return nil, usage
}

View File

@@ -9,8 +9,7 @@ import (
"mime/multipart"
"net/http"
"net/textproto"
"one-api/common"
constant2 "one-api/constant"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
@@ -21,7 +20,7 @@ import (
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
"path/filepath"
"strings"
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime {
if info.RelayMode == relayconstant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = "wss://" + baseUrl
@@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
}
switch info.ChannelType {
case common.ChannelTypeAzure:
case constant.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
apiVersion = constant2.AzureDefaultAPIVersion
apiVersion = constant.AzureDefaultAPIVersion
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
@@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
task := strings.TrimPrefix(requestURL, "/v1/")
// 特殊处理 responses API
if info.RelayMode == constant.RelayModeResponses {
if info.RelayMode == relayconstant.RelayModeResponses {
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
}
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
model_ = strings.Replace(model_, ".", "", -1)
}
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime {
if info.RelayMode == relayconstant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case common.ChannelTypeMiniMax:
case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
case common.ChannelTypeCustom:
case constant.ChannelTypeCustom:
url := info.BaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
@@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, header)
if info.ChannelType == common.ChannelTypeAzure {
if info.ChannelType == constant.ChannelTypeAzure {
header.Set("api-key", info.ApiKey)
return nil
}
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == constant.RelayModeRealtime {
if info.RelayMode == relayconstant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
@@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
if info.ChannelType == common.ChannelTypeOpenRouter {
if info.ChannelType == constant.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
header.Set("X-Title", "New API")
}
@@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
request.StreamOptions = nil
}
if info.ChannelType == common.ChannelTypeOpenRouter {
if info.ChannelType == constant.ChannelTypeOpenRouter {
if len(request.Usage) == 0 {
request.Usage = json.RawMessage(`{"include":true}`)
}
@@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == constant.RelayModeAudioSpeech {
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
@@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case constant.RelayModeImagesEdits:
case relayconstant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == constant.RelayModeAudioTranscription ||
info.RelayMode == constant.RelayModeAudioTranslation ||
info.RelayMode == constant.RelayModeImagesEdits {
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
info.RelayMode == relayconstant.RelayModeImagesEdits {
return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == constant.RelayModeRealtime {
} else if info.RelayMode == relayconstant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeRealtime:
case relayconstant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case constant.RelayModeAudioSpeech:
case relayconstant.RelayModeAudioSpeech:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeAudioTranslation:
case relayconstant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err, usage = OpenaiHandlerWithUsage(c, resp, info)
case constant.RelayModeRerank:
case relayconstant.RelayModeRerank:
err, usage = common_handler.RerankHandler(c, info, resp)
case constant.RelayModeResponses:
case relayconstant.RelayModeResponses:
if info.IsStream {
err, usage = OaiResponsesStreamHandler(c, resp, info)
} else {
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
func (a *Adaptor) GetModelList() []string {
switch a.ChannelType {
case common.ChannelType360:
case constant.ChannelType360:
return ai360.ModelList
case common.ChannelTypeMoonshot:
case constant.ChannelTypeMoonshot:
return moonshot.ModelList
case common.ChannelTypeLingYiWanWu:
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
case common.ChannelTypeMiniMax:
case constant.ChannelTypeMiniMax:
return minimax.ModelList
case common.ChannelTypeXinference:
case constant.ChannelTypeXinference:
return xinference.ModelList
case common.ChannelTypeOpenRouter:
case constant.ChannelTypeOpenRouter:
return openrouter.ModelList
default:
return ModelList
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string {
switch a.ChannelType {
case common.ChannelType360:
case constant.ChannelType360:
return ai360.ChannelName
case common.ChannelTypeMoonshot:
case constant.ChannelTypeMoonshot:
return moonshot.ChannelName
case common.ChannelTypeLingYiWanWu:
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
case common.ChannelTypeMiniMax:
case constant.ChannelTypeMiniMax:
return minimax.ChannelName
case common.ChannelTypeXinference:
case constant.ChannelTypeXinference:
return xinference.ChannelName
case common.ChannelTypeOpenRouter:
case constant.ChannelTypeOpenRouter:
return openrouter.ChannelName
default:
return ChannelName

View File

@@ -2,7 +2,6 @@ package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
@@ -34,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
return err
}
@@ -111,12 +110,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
}
containStreamUsage := false
defer common.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName
var responseId string
var createAt int64 = 0
var systemFingerprint string
model := info.UpstreamModelName
var containStreamUsage bool
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
@@ -148,31 +148,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return true
})
// 处理最后的响应
shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
if err == nil {
responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
containStreamUsage = true
usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
shouldSendLastResp = false
}
}
for _, choice := range lastStreamResponse.Choices {
if choice.FinishReason != nil {
shouldSendLastResp = true
}
}
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
common.SysError("error handling last response: " + err.Error())
}
if shouldSendLastResp {
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 处理token计算
@@ -184,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == common.ChannelTypeDeepSeek {
if info.ChannelType == constant.ChannelTypeDeepSeek {
if usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
@@ -197,16 +181,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = common.DecodeJson(responseBody, &simpleResponse)
err = common.UnmarshalJson(responseBody, &simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -238,7 +220,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if forceFormat {
responseBody, err = json.Marshal(simpleResponse)
responseBody, err = common.EncodeJson(simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -247,29 +229,15 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
}
case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := json.Marshal(claudeResp)
claudeRespStr, err := common.EncodeJson(claudeResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
responseBody = claudeRespStr
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
common.IOCopyBytesGracefully(c, resp, responseBody)
return nil, &simpleResponse.Usage
}
@@ -280,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer resp.Body.Close()
defer common.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
@@ -297,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
if err != nil {
@@ -306,25 +276,8 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
@@ -415,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
}
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
err = common.UnmarshalJson(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
@@ -475,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = json.Unmarshal(message, realtimeEvent)
err = common.UnmarshalJson(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
@@ -522,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage = &dto.RealtimeUsage{}
// print now usage
}
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
@@ -601,40 +554,25 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
}
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
// reset content length
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var usageResp dto.SimpleResponse
err = json.Unmarshal(responseBody, &usageResp)
err = common.UnmarshalJson(responseBody, &usageResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content
// We should still perform billing even if parsing fails
// format
if usageResp.InputTokens > 0 {
usageResp.PromptTokens += usageResp.InputTokens

View File

@@ -1,7 +1,6 @@
package openai
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -16,17 +15,15 @@ import (
)
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
// read response body
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = common.DecodeJson(responseBody, &responsesResponse)
err = common.UnmarshalJson(responseBody, &responsesResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -41,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}, nil
}
// reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
// copy response body
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
// compute usage
usage := dto.Usage{}
usage.PromptTokens = responsesResponse.Usage.InputTokens
@@ -82,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
// 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse
if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":

View File

@@ -83,12 +83,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
common.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
@@ -122,10 +117,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
common.CloseResponseBodyGracefully(resp)
return nil, responseText
}
@@ -134,10 +126,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
)
@@ -14,10 +15,7 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {

View File

@@ -0,0 +1,380 @@
package jimeng
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"one-api/model"
"sort"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type requestPayload struct {
ReqKey string `json:"req_key"`
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
ImageUrls []string `json:"image_urls,omitempty"`
Prompt string `json:"prompt,omitempty"`
Seed int64 `json:"seed"`
AspectRatio string `json:"aspect_ratio"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
}
type responseTask struct {
Code int `json:"code"`
Data struct {
BinaryDataBase64 []interface{} `json:"binary_data_base64"`
ImageUrls interface{} `json:"image_urls"`
RespData string `json:"resp_data"`
Status string `json:"status"`
VideoUrl string `json:"video_url"`
} `json:"data"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Status int `json:"status"`
TimeElapsed string `json:"time_elapsed"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
accessKey string
secretKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
// apiKey format: "access_key|secret_key"
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 2 {
a.accessKey = strings.TrimSpace(keyParts[0])
a.secretKey = strings.TrimSpace(keyParts[1])
}
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate
info.Action = action
req := relaycommon.TaskSubmitReq{}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
return a.signRequest(req, a.accessKey, a.secretKey)
}
// BuildRequestBody converts request into Jimeng specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
_ = resp.Body.Close()
// Parse Jimeng response
var jResp responsePayload
if err := json.Unmarshal(responseBody, &jResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if jResp.Code != 10000 {
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
return jResp.Data.TaskID, responseBody, nil
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
payload := map[string]string{
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, errors.Wrap(err, "marshal fetch task payload failed")
}
req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
keyParts := strings.Split(key, "|")
if len(keyParts) != 2 {
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
}
accessKey := strings.TrimSpace(keyParts[0])
secretKey := strings.TrimSpace(keyParts[1])
if err := a.signRequest(req, accessKey, secretKey); err != nil {
return nil, errors.Wrap(err, "sign request failed")
}
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"jimeng_vgfm_t2v_l20"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "jimeng"
}
func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
var bodyBytes []byte
var err error
if req.Body != nil {
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return errors.Wrap(err, "read request body failed")
}
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
} else {
bodyBytes = []byte{}
}
payloadHash := sha256.Sum256(bodyBytes)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
t := time.Now().UTC()
xDate := t.Format("20060102T150405Z")
shortDate := t.Format("20060102")
req.Header.Set("Host", req.URL.Host)
req.Header.Set("X-Date", xDate)
req.Header.Set("X-Content-Sha256", hexPayloadHash)
// Sort and encode query parameters to create canonical query string
queryParams := req.URL.Query()
sortedKeys := make([]string, 0, len(queryParams))
for k := range queryParams {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
var queryParts []string
for _, k := range sortedKeys {
values := queryParams[k]
sort.Strings(values)
for _, v := range values {
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
}
}
canonicalQueryString := strings.Join(queryParts, "&")
headersToSign := map[string]string{
"host": req.URL.Host,
"x-date": xDate,
"x-content-sha256": hexPayloadHash,
}
if req.Header.Get("Content-Type") != "" {
headersToSign["content-type"] = req.Header.Get("Content-Type")
}
var signedHeaderKeys []string
for k := range headersToSign {
signedHeaderKeys = append(signedHeaderKeys, k)
}
sort.Strings(signedHeaderKeys)
var canonicalHeaders strings.Builder
for _, k := range signedHeaderKeys {
canonicalHeaders.WriteString(k)
canonicalHeaders.WriteString(":")
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
canonicalHeaders.WriteString("\n")
}
signedHeaders := strings.Join(signedHeaderKeys, ";")
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
req.Method,
req.URL.Path,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
hexPayloadHash,
)
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
region := "cn-north-1"
serviceName := "cv"
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
xDate,
credentialScope,
hexHashedCanonicalRequest,
)
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(serviceName))
kSigning := hmacSHA256(kService, []byte("request"))
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
accessKey,
credentialScope,
signedHeaders,
signature,
)
req.Header.Set("Authorization", authorization)
return nil
}
func hmacSHA256(key []byte, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{
ReqKey: "jimeng_vgfm_i2v_l20",
Prompt: req.Prompt,
AspectRatio: "16:9", // Default aspect ratio
Seed: -1, // Default to random
}
// Handle one-of image_urls or binary_data_base64
if req.Image != "" {
if strings.HasPrefix(req.Image, "http") {
r.ImageUrls = []string{req.Image}
} else {
r.BinaryDataBase64 = []string{req.Image}
}
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
if err := json.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
taskResult := relaycommon.TaskInfo{}
if resTask.Code == 10000 {
taskResult.Code = 0
} else {
taskResult.Code = resTask.Code // todo uni code
taskResult.Reason = resTask.Message
taskResult.Status = model.TaskStatusFailure
taskResult.Progress = "100%"
}
switch resTask.Data.Status {
case "in_queue":
taskResult.Status = model.TaskStatusQueued
taskResult.Progress = "10%"
case "done":
taskResult.Status = model.TaskStatusSuccess
taskResult.Progress = "100%"
}
taskResult.Url = resTask.Data.VideoUrl
return &taskResult, nil
}

View File

@@ -2,11 +2,12 @@ package kling
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/samber/lo"
"io"
"net/http"
"one-api/model"
"strings"
"time"
@@ -15,6 +16,7 @@ import (
"github.com/pkg/errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
@@ -41,16 +43,27 @@ type requestPayload struct {
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Model string `json:"model,omitempty"`
ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TaskID string `json:"task_id"`
Code int `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Data struct {
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
TaskStatusMsg string `json:"task_status_msg"`
TaskResult struct {
Videos []struct {
Id string `json:"id"`
Url string `json:"url"`
Duration string `json:"duration"`
} `json:"videos"`
} `json:"task_result"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
} `json:"data"`
}
@@ -80,7 +93,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := "generate"
action := constant.TaskActionGenerate
info.Action = action
var req SubmitReq
@@ -94,13 +107,14 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
}
// Store into context for later usage
c.Set("kling_request", req)
c.Set("task_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
return fmt.Sprintf("%s%s", a.baseURL, path), nil
}
// BuildRequestHeader sets required headers.
@@ -119,13 +133,16 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
// BuildRequestBody converts request into Kling specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("kling_request")
v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
body := a.convertToRequestPayload(&req)
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, err
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
@@ -135,6 +152,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
if action := c.GetString("action"); action != "" {
info.Action = action
}
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
@@ -149,8 +169,8 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
// Attempt Kling response parse first.
var kResp responsePayload
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
return kResp.Data.TaskID, responseBody, nil
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
return kResp.Data.TaskId, responseBody, nil
}
// Fallback generic task response.
@@ -175,7 +195,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
action, ok := body["action"].(string)
if !ok {
return nil, fmt.Errorf("invalid action")
}
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@@ -187,10 +212,6 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
token = key
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req = req.WithContext(ctx)
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
@@ -210,22 +231,29 @@ func (a *TaskAdaptor) GetChannelName() string {
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
r := &requestPayload{
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
Model: req.Model,
ModelName: req.Model,
CfgScale: 0.5,
}
if r.Model == "" {
r.Model = "kling-v1"
if r.ModelName == "" {
r.ModelName = "kling-v1"
}
return r
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
func (a *TaskAdaptor) getAspectRatio(size string) string {
@@ -286,27 +314,33 @@ func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (strin
return token.SignedString([]byte(secretKey))
}
// ParseResultUrl 提取视频任务结果的 url
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
data, ok := resp["data"].(map[string]any)
if !ok {
return "", fmt.Errorf("data field not found or invalid")
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resPayload := responsePayload{}
err := json.Unmarshal(respBody, &resPayload)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
taskResult, ok := data["task_result"].(map[string]any)
if !ok {
return "", fmt.Errorf("task_result field not found or invalid")
taskInfo := &relaycommon.TaskInfo{}
taskInfo.Code = resPayload.Code
taskInfo.TaskID = resPayload.Data.TaskId
taskInfo.Reason = resPayload.Message
//任务状态枚举值submitted已提交、processing处理中、succeed成功、failed失败
status := resPayload.Data.TaskStatus
switch status {
case "submitted":
taskInfo.Status = model.TaskStatusSubmitted
case "processing":
taskInfo.Status = model.TaskStatusInProgress
case "succeed":
taskInfo.Status = model.TaskStatusSuccess
case "failed":
taskInfo.Status = model.TaskStatusFailure
default:
return nil, fmt.Errorf("unknown task status: %s", status)
}
videos, ok := taskResult["videos"].([]interface{})
if !ok || len(videos) == 0 {
return "", fmt.Errorf("videos field not found or empty")
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
video := videos[0]
taskInfo.Url = video.Url
}
video, ok := videos[0].(map[string]interface{})
if !ok {
return "", fmt.Errorf("video item invalid")
}
url, ok := video["url"].(string)
if !ok || url == "" {
return "", fmt.Errorf("url field not found or invalid")
}
return url, nil
return taskInfo, nil
}

View File

@@ -22,8 +22,8 @@ type TaskAdaptor struct {
ChannelType int
}
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
return "", nil // todo implement this method if needed
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
return nil, fmt.Errorf("not implement") // todo implement this method if needed
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {

View File

@@ -124,10 +124,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
helper.Done(c)
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
common.CloseResponseBodyGracefully(resp)
return nil, responseText
}
@@ -138,10 +135,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -11,6 +11,7 @@ import (
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"fmt"
@@ -45,7 +46,7 @@ func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
newToken, err := exchangeJwtForAccessToken(signedJWT)
newToken, err := exchangeJwtForAccessToken(signedJWT, info)
if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
}
@@ -96,14 +97,25 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) {
return signedToken, nil
}
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
resp, err := http.PostForm(authURL, data)
var client *http.Client
var err error
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
client, err = service.NewProxyHttpClient(proxyURL.(string))
if err != nil {
return "", fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
resp, err := client.PostForm(authURL, data)
if err != nil {
return "", err
}

View File

@@ -1,9 +1,7 @@
package xai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -13,6 +11,8 @@ import (
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
@@ -73,18 +73,16 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
helper.Done(c)
err := resp.Body.Close()
if err != nil {
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
common.CloseResponseBodyGracefully(resp)
return nil, usage
}
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
var response *dto.TextResponse
err = common.DecodeJson(responseBody, &response)
var response *dto.SimpleResponse
err = common.UnmarshalJson(responseBody, &response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return nil, nil
@@ -99,21 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
return nil, nil
}
// set new body
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.IOCopyBytesGracefully(c, resp, encodeJson)
return nil, &response.Usage
}

View File

@@ -1,7 +1,7 @@
package xinference
type XinRerankResponseDocument struct {
Document string `json:"document,omitempty"`
Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}

View File

@@ -210,10 +210,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
return nil, usage
}
@@ -223,10 +220,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -113,17 +113,17 @@ type RelayInfo struct {
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
common.ChannelTypeOpenAI: true,
common.ChannelTypeAnthropic: true,
common.ChannelTypeAws: true,
common.ChannelTypeGemini: true,
common.ChannelCloudflare: true,
common.ChannelTypeAzure: true,
common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true,
common.ChannelTypeXai: true,
common.ChannelTypeDeepSeek: true,
common.ChannelTypeBaiduV2: true,
constant.ChannelTypeOpenAI: true,
constant.ChannelTypeAnthropic: true,
constant.ChannelTypeAws: true,
constant.ChannelTypeGemini: true,
constant.ChannelCloudflare: true,
constant.ChannelTypeAzure: true,
constant.ChannelTypeVolcEngine: true,
constant.ChannelTypeOllama: true,
constant.ChannelTypeXai: true,
constant.ChannelTypeDeepSeek: true,
constant.ChannelTypeBaiduV2: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -211,40 +211,40 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
channelSetting := c.GetStringMap("channel_setting")
paramOverride := c.GetStringMap("param_override")
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
tokenId := c.GetInt("token_id")
tokenKey := c.GetString("token_key")
userId := c.GetInt("id")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
// firstResponseTime = time.Now() - 1 second
apiType, _ := relayconstant.ChannelType2APIType(channelType)
apiType, _ := common.ChannelType2APIType(channelType)
info := &RelayInfo{
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
UserEmail: c.GetString(constant.ContextKeyUserEmail),
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
TokenKey: tokenKey,
UserId: userId,
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
UserGroup: c.GetString(constant.ContextKeyUserGroup),
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
TokenUnlimited: tokenUnlimited,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"),
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
//RecodeModelName: c.GetString("original_model"),
IsModelMapped: false,
ApiType: apiType,
@@ -266,12 +266,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
info.RequestURLPath = "/v1" + info.RequestURLPath
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
info.BaseUrl = constant.ChannelBaseURLs[channelType]
}
if info.ChannelType == common.ChannelTypeAzure {
if info.ChannelType == constant.ChannelTypeAzure {
info.ApiVersion = GetAPIVersion(c)
}
if info.ChannelType == common.ChannelTypeVertexAi {
if info.ChannelType == constant.ChannelTypeVertexAi {
info.ApiVersion = c.GetString("region")
}
if streamSupportedChannels[info.ChannelType] {
@@ -313,3 +313,22 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
}
return info
}
type TaskSubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type TaskInfo struct {
Code int `json:"code"`
TaskID string `json:"task_id"`
Status string `json:"status"`
Reason string `json:"reason,omitempty"`
Url string `json:"url,omitempty"`
Progress string `json:"progress,omitempty"`
}

View File

@@ -6,7 +6,7 @@ import (
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"one-api/common"
"one-api/constant"
"strings"
)
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
case constant.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
case constant.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}

View File

@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
@@ -16,17 +17,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
var jinaResp dto.RerankResponse
if info.ChannelType == common.ChannelTypeXinference {
if info.ChannelType == constant.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse
err = common.DecodeJson(responseBody, &xinRerankResponse)
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
@@ -38,10 +36,16 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
}
if info.ReturnDocuments {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
if result.Document != nil {
if doc, ok := result.Document.(string); ok {
if doc == "" {
document = info.Documents[result.Index]
} else {
document = doc
}
} else {
document = result.Document
}
}
respResult.Document = document
}
@@ -55,7 +59,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
},
}
} else {
err = common.DecodeJson(responseBody, &jinaResp)
err = common.UnmarshalJson(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}

View File

@@ -1,106 +0,0 @@
package constant
import (
"one-api/common"
)
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeZhipuV4
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeCoze
APITypeDummy // this one is only for count, do not add any channel after this
)
func ChannelType2APIType(channelType int) (int, bool) {
apiType := -1
switch channelType {
case common.ChannelTypeOpenAI:
apiType = APITypeOpenAI
case common.ChannelTypeAnthropic:
apiType = APITypeAnthropic
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
case common.ChannelTypeXunfei:
apiType = APITypeXunfei
case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeZhipu_v4:
apiType = APITypeZhipuV4
case common.ChannelTypeOllama:
apiType = APITypeOllama
case common.ChannelTypePerplexity:
apiType = APITypePerplexity
case common.ChannelTypeAws:
apiType = APITypeAws
case common.ChannelTypeCohere:
apiType = APITypeCohere
case common.ChannelTypeDify:
apiType = APITypeDify
case common.ChannelTypeJina:
apiType = APITypeJina
case common.ChannelCloudflare:
apiType = APITypeCloudflare
case common.ChannelTypeSiliconFlow:
apiType = APITypeSiliconFlow
case common.ChannelTypeVertexAi:
apiType = APITypeVertexAi
case common.ChannelTypeMistral:
apiType = APITypeMistral
case common.ChannelTypeDeepSeek:
apiType = APITypeDeepSeek
case common.ChannelTypeMokaAI:
apiType = APITypeMokaAI
case common.ChannelTypeVolcEngine:
apiType = APITypeVolcEngine
case common.ChannelTypeBaiduV2:
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
case common.ChannelTypeXinference:
apiType = APITypeXinference
case common.ChannelTypeXai:
apiType = APITypeXai
case common.ChannelTypeCoze:
apiType = APITypeCoze
}
if apiType == -1 {
return APITypeOpenAI, false
}
return apiType, true
}

View File

@@ -41,6 +41,9 @@ const (
RelayModeKlingFetchByID
RelayModeKlingSubmit
RelayModeJimengFetchByID
RelayModeJimengSubmit
RelayModeRerank
RelayModeResponses
@@ -80,7 +83,7 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
} else if strings.HasPrefix(path, "/v1beta/models") {
} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
relayMode = RelayModeGemini
}
return relayMode
@@ -146,3 +149,13 @@ func Path2RelayKling(method, path string) int {
}
return relayMode
}
func Path2RelayJimeng(method, path string) int {
relayMode := RelayModeUnknown
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
relayMode = RelayModeJimengSubmit
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
relayMode = RelayModeJimengFetchByID
}
return relayMode
}

View File

@@ -20,8 +20,8 @@ import (
)
const (
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
@@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
writeMutex sync.Mutex // Mutex to protect concurrent writes
writeMutex sync.Mutex // Mutex to protect concurrent writes
wg sync.WaitGroup // 用于等待所有 goroutine 退出
)
@@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
pingTicker = time.NewTicker(pingInterval)
}
if common.DebugEnabled {
// print timeout and ping interval for debugging
println("relay timeout seconds:", common.RelayTimeout)
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
println("ping interval seconds:", int64(pingInterval.Seconds()))
}
// 改进资源清理,确保所有 goroutine 正确退出
defer func() {
// 通知所有 goroutine 停止
common.SafeSendBool(stopChan, true)
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
// 等待所有 goroutine 退出最多等待5秒
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
common.LogError(c, "timeout waiting for goroutines to exit")
}
close(stopChan)
}()
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
@@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
println("ping goroutine exited")
}
}()
// 添加超时保护,防止 goroutine 无限运行
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
pingTimeout := time.NewTimer(maxPingDuration)
defer pingTimeout.Stop()
for {
select {
case <-pingTicker.C:
@@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer writeMutex.Unlock()
done <- PingData(c)
}()
select {
case err := <-done:
if err != nil {
@@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
println("scanner goroutine exited")
}
}()
for scanner.Scan() {
// 检查是否需要停止
select {
@@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return
default:
}
ticker.Reset(streamingTimeout)
data := scanner.Text()
if common.DebugEnabled {
@@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\r")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
// 使用超时机制防止写操作阻塞
done := make(chan bool, 1)
go func() {
@@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
defer writeMutex.Unlock()
done <- dataHandler(data)
}()
select {
case success := <-done:
if !success {

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -17,8 +18,6 @@ import (
"one-api/setting"
"strings"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)

View File

@@ -279,10 +279,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
}
common.IOCopyBytesGracefully(c, nil, respBody)
return nil
}

View File

@@ -1,6 +1,7 @@
package relay
import (
"one-api/constant"
commonconstant "one-api/constant"
"one-api/relay/channel"
"one-api/relay/channel/ali"
@@ -22,6 +23,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
@@ -31,7 +33,6 @@ import (
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
"one-api/relay/constant"
)
func GetAdaptor(apiType int) channel.Adaptor {
@@ -104,6 +105,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
return &suno.TaskAdaptor{}
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
case commonconstant.TaskPlatformJimeng:
return &jimeng.TaskAdaptor{}
}
return nil
}

View File

@@ -245,7 +245,7 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
}
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("id")
taskId := c.Param("task_id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)

View File

@@ -78,12 +78,15 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
if common.DebugEnabled {
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)

View File

@@ -63,6 +63,7 @@ func SetRelayRouter(router *gin.Engine) {
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
httpRouter.POST("/moderations", controller.Relay)
httpRouter.POST("/rerank", controller.Relay)
httpRouter.POST("/models/*path", controller.Relay)
}
relayMjRouter := router.Group("/mj")

View File

@@ -14,4 +14,11 @@ func SetVideoRouter(router *gin.Engine) {
videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
}
klingV1Router := router.Group("/kling/v1")
klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
{
klingV1Router.POST("/videos/text2video", controller.RelayTask)
klingV1Router.POST("/videos/image2video", controller.RelayTask)
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/setting/operation_setting"
@@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
}
if err.StatusCode == http.StatusForbidden {
switch channelType {
case common.ChannelTypeGemini:
case constant.ChannelTypeGemini:
return true
}
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
@@ -19,7 +20,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
Stream: claudeRequest.Stream,
}
isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
if isOpenRouter {

View File

@@ -90,10 +90,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatu
if err != nil {
return
}
err = resp.Body.Close()
if err != nil {
return
}
common.CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {

View File

@@ -13,9 +13,8 @@ import (
)
var httpClient *http.Client
var impatientHTTPClient *http.Client
func init() {
func InitHttpClient() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
@@ -23,20 +22,12 @@ func init() {
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second,
}
}
func GetHttpClient() *http.Client {
return httpClient
}
func GetImpatientHttpClient() *http.Client {
return impatientHTTPClient
}
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
if proxyURL == "" {

View File

@@ -228,10 +228,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
err = resp.Body.Close()
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
common.CloseResponseBodyGracefully(resp)
respStr := string(responseBody)
log.Printf("respStr: %s", respStr)
if respStr == "" {

View File

@@ -6,7 +6,7 @@ import (
"log"
"math"
"one-api/common"
constant2 "one-api/constant"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -232,7 +232,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheCreationRatio := priceData.CacheCreationRatio
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
promptTokens -= cacheTokens
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
@@ -447,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
gopool.Go(func() {
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
threshold = int(userCustomThreshold.(float64))
}

View File

@@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
return 3 * baseTokens, nil
}
var config image.Config
@@ -172,9 +172,6 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
}
}
toolTokens := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
tkm += 8
tkm += toolTokens
}
@@ -195,9 +192,6 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message
if request.System != "" {
systemTokens := CountTokenInput(request.System, model)
if err != nil {
return 0, err
}
tkm += systemTokens
}

View File

@@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
}
// 发送请求
client := GetImpatientHttpClient()
client := GetHttpClient()
resp, err = client.Do(req)
if err != nil {
return fmt.Errorf("failed to send webhook request: %v", err)

View File

@@ -501,16 +501,19 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
} else if strings.HasPrefix(name, "gemini-2.0") {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.5-pro") { // 移除preview来增加兼容性这里假设正式版的倍率和preview一致
return 8, true
return 8, false
} else if strings.HasPrefix(name, "gemini-2.5-flash") { // 处理不同的flash模型倍率
if strings.HasPrefix(name, "gemini-2.5-flash-preview") {
if strings.HasSuffix(name, "-nothinking") {
return 4, true
return 4, false
}
return 3.5 / 0.15, true
return 3.5 / 0.15, false
}
if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") {
return 4, true
if strings.HasPrefix(name, "gemini-2.5-flash-lite") {
if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") {
return 4, false
}
return 4, false
}
return 2.5 / 0.3, true
}

42
types/set.go Normal file
View File

@@ -0,0 +1,42 @@
package types
type Set[T comparable] struct {
items map[T]struct{}
}
// NewSet 创建并返回一个新的 Set
func NewSet[T comparable]() *Set[T] {
return &Set[T]{
items: make(map[T]struct{}),
}
}
func (s *Set[T]) Add(item T) {
s.items[item] = struct{}{}
}
// Remove 从 Set 中移除一个元素
func (s *Set[T]) Remove(item T) {
delete(s.items, item)
}
// Contains 检查 Set 是否包含某个元素
func (s *Set[T]) Contains(item T) bool {
_, exists := s.items[item]
return exists
}
// Len 返回 Set 中元素的数量
func (s *Set[T]) Len() int {
return len(s.items)
}
// Items 返回 Set 中所有元素组成的切片
// 注意:由于 map 的无序性,返回的切片元素顺序是随机的
func (s *Set[T]) Items() []T {
items := make([]T, 0, s.Len())
for item := range s.items {
items = append(items, item)
}
return items
}

View File

@@ -34,20 +34,20 @@ import LinuxDoIcon from '../common/logo/LinuxDoIcon.js';
import { useTranslation } from 'react-i18next';
const LoginForm = () => {
let navigate = useNavigate();
const { t } = useTranslation();
const [inputs, setInputs] = useState({
username: '',
password: '',
wechat_verification_code: '',
});
const { username, password } = inputs;
const [searchParams, setSearchParams] = useSearchParams();
const [submitted, setSubmitted] = useState(false);
const { username, password } = inputs;
const [userState, userDispatch] = useContext(UserContext);
const [turnstileEnabled, setTurnstileEnabled] = useState(false);
const [turnstileSiteKey, setTurnstileSiteKey] = useState('');
const [turnstileToken, setTurnstileToken] = useState('');
let navigate = useNavigate();
const [status, setStatus] = useState({});
const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false);
const [showEmailLogin, setShowEmailLogin] = useState(false);
const [wechatLoading, setWechatLoading] = useState(false);
@@ -59,7 +59,6 @@ const LoginForm = () => {
const [resetPasswordLoading, setResetPasswordLoading] = useState(false);
const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false);
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
const { t } = useTranslation();
const logo = getLogo();
const systemName = getSystemName();
@@ -69,19 +68,22 @@ const LoginForm = () => {
localStorage.setItem('aff', affCode);
}
const [status] = useState(() => {
const savedStatus = localStorage.getItem('status');
return savedStatus ? JSON.parse(savedStatus) : {};
});
useEffect(() => {
if (status.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(status.turnstile_site_key);
}
}, [status]);
useEffect(() => {
if (searchParams.get('expired')) {
showError(t('未登录或登录已过期,请重新登录'));
}
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
setStatus(status);
if (status.turnstile_check) {
setTurnstileEnabled(true);
setTurnstileSiteKey(status.turnstile_site_key);
}
}
}, []);
const onWeChatLoginClicked = () => {
@@ -356,9 +358,19 @@ const LoginForm = () => {
</Button>
</div>
<div className="mt-6 text-center text-sm">
<Text>{t('没有账户?')} <Link to="/register" className="text-blue-600 hover:text-blue-800 font-medium">{t('注册')}</Link></Text>
</div>
{!status.self_use_mode_enabled && (
<div className="mt-6 text-center text-sm">
<Text>
{t('没有账户?')}{' '}
<Link
to="/register"
className="text-blue-600 hover:text-blue-800 font-medium"
>
{t('注册')}
</Link>
</Text>
</div>
)}
</div>
</Card>
</div>
@@ -387,7 +399,6 @@ const LoginForm = () => {
placeholder={t('请输入您的用户名或邮箱地址')}
name="username"
size="large"
className="!rounded-md"
onChange={(value) => handleChange('username', value)}
prefix={<IconMail />}
/>
@@ -399,7 +410,6 @@ const LoginForm = () => {
name="password"
mode="password"
size="large"
className="!rounded-md"
onChange={(value) => handleChange('password', value)}
prefix={<IconLock />}
/>
@@ -451,9 +461,19 @@ const LoginForm = () => {
</>
)}
<div className="mt-6 text-center text-sm">
<Text>{t('没有账户?')} <Link to="/register" className="text-blue-600 hover:text-blue-800 font-medium">{t('注册')}</Link></Text>
</div>
{!status.self_use_mode_enabled && (
<div className="mt-6 text-center text-sm">
<Text>
{t('没有账户?')}{' '}
<Link
to="/register"
className="text-blue-600 hover:text-blue-800 font-medium"
>
{t('注册')}
</Link>
</Text>
</div>
)}
</div>
</Card>
</div>
@@ -499,8 +519,11 @@ const LoginForm = () => {
};
return (
<div className="bg-gray-100 flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
<div className="w-full max-w-sm">
<div className="relative overflow-hidden bg-gray-100 flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
{/* 背景模糊晕染球 */}
<div className="blur-ball blur-ball-indigo" style={{ top: '-80px', right: '-80px', transform: 'none' }} />
<div className="blur-ball blur-ball-teal" style={{ top: '50%', left: '-120px' }} />
<div className="w-full max-w-sm mt-[64px]">
{showEmailLogin || !(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth)
? renderEmailLoginForm()
: renderOAuthOptions()}

Some files were not shown because too many files have changed in this diff Show More