Compare commits

..

130 Commits

Author SHA1 Message Date
Calcium-Ion
cf4700a35c Merge pull request #1278 from QuantumNous/alpha
feat: conditionally set Gemini ThinkingBudget based on MaxOutputTokens
2025-06-21 18:27:05 +08:00
CaIon
6bb552128c feat(relay-gemini): conditionally set ThinkingBudget based on MaxOutputTokens 2025-06-21 17:51:13 +08:00
CaIon
50b4fc06f8 Merge branch 'main' into alpha 2025-06-21 17:04:29 +08:00
creamlike1024
59574dc80f Merge branch 'bddiudiu-feat_images' 2025-06-21 10:31:53 +08:00
creamlike1024
7577ec1ac4 Merge branch 'feat_images' of github.com:bddiudiu/new-api into bddiudiu-feat_images 2025-06-21 10:31:37 +08:00
t0ng7u
d487be0029 feat(settings/announcements): sort by publishDate desc
Add reverse-chronological sorting for the announcements list so that the newest
items appear first in the dashboard.

No API changes; this only affects front-end display and user notifications.
2025-06-21 06:15:26 +08:00
Apple\Apple
83a3872b97 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-21 06:06:35 +08:00
Apple\Apple
1ad2f63f85 feat: Enhance announcements UX with unread badge, tabbed NoticeModal, and shine animation
• HeaderBar
  - Added dynamic unread badge; click now opens NoticeModal on “System Announcements” tab
  - Passes `defaultTab` and `unreadKeys` props to NoticeModal for contextual behaviour

• NoticeModal
  - Introduced Tabs inside the modal title with Lucide icons (Bell, Megaphone)
  - Displays in-app notice (markdown) and system announcements separately
  - Highlights unread announcements with “shine” text animation
  - Accepts new props `defaultTab`, `unreadKeys` to control initial tab and highlight logic

• CSS (index.css)
  - Implemented `sweep-shine` keyframes and `.shine-text` utility for left-to-right glow
  - Added dark-mode variant for better contrast
  - Ensured cross-browser support with standard `background-clip`

Overall, users now see an unread counter, are directed to new announcements automatically, and benefit from an eye-catching glow effect that works in both light and dark themes.
2025-06-21 06:06:21 +08:00
Calcium-Ion
fcaa8317e4 Merge pull request #1160 from feitianbubu/add-moonshot-kimi-update-balance
feat: add moonshot(kimi) update balance
2025-06-21 05:03:45 +08:00
Calcium-Ion
ccda14255a Merge pull request #1135 from PeterDaveHelloKitchen/Dockerfile
refactor: optimize Dockerfile apk usage
2025-06-21 05:03:16 +08:00
Calcium-Ion
8d66828229 Merge pull request #1210 from RedwindA/feat/128-budget
feat: allow for a lower percentage of the thinkingBudget
2025-06-21 04:54:46 +08:00
Calcium-Ion
4ebf9e35e1 Merge pull request #1248 from RedwindA/update-gemini-ratio
feat(model-ratio): add default ratios for new Gemini models and refine flash model handling
2025-06-21 04:51:41 +08:00
t0ng7u
2902d6c7c2 🎨 style: add mt-2 style in ratio section 2025-06-21 04:25:56 +08:00
t0ng7u
01ef1fe4e4 📝 refactor: reorganize payment settings into dedicated tab
Restructure payment settings into a separate tab for better organization and user experience. The changes include:

1. Create dedicated Payment components in the Setting directory structure
2. Move payment-related settings from SystemSetting to PaymentSetting
3. Add proper i18n support with useTranslation hook
4. Split payment settings into GeneralPayment and PaymentGateway components
5. Fix internationalization issues in placeholder text
6. Update navigation with CreditCard icon for payment tab

This refactoring improves code maintainability by following the established project pattern of having specialized setting components in their own directories.
2025-06-21 04:16:01 +08:00
Apple\Apple
c3d2d07b68 🚚 refactor: Move DataDashboard settings from Operation to Dashboard section
This commit relocates the DataDashboard settings component from the Operation section to the Dashboard section for better logical organization. The changes include:

- Remove DataDashboard import and component from OperationSetting.js
- Add DataDashboard component to DashboardSetting.js
- Update import path from Operation to Dashboard directory
- Add DataExport related state management in DashboardSetting

This restructuring improves the application's information architecture by grouping related dashboard visualization settings together.
2025-06-21 02:56:38 +08:00
Apple\Apple
18417bacb3 🎨 refactor(UI): Move drawing settings to a separate tab
- Create a new DrawingSetting component for managing drawing-related configurations
- Add a dedicated "Drawing Settings" tab with Palette icon in the settings page
- Remove drawing settings section from the OperationSetting component
- Update import path to use Drawing directory instead of Operation directory
- Improve UI organization by separating drawing settings from general operations
2025-06-21 02:50:09 +08:00
Apple\Apple
8ec18dd21b 💬 refactor: separate chat settings into dedicated tab
- Create new ChatsSetting component for managing chat configurations
- Add "Chat Settings" tab with MessageSquare icon in settings page
- Remove chat settings section from OperationSetting component
- Update import path to use Chat directory structure
2025-06-21 02:36:09 +08:00
t0ng7u
edaff1c689 🎨 feat(UI): Add Lucide icons to settings tabs for improved navigation
- Add icons to each settings tab to enhance visual recognition
- Import necessary Lucide React icons (Settings, Calculator, Gauge, Shapes, etc.)
- Create consistent tab styling with icons aligned next to text
- Reorder tabs to place "Other Settings" as the last option
- Improve overall settings page UI with better visual hierarchy
2025-06-21 02:21:27 +08:00
Apple\Apple
9c3a13cb23 📝 i18n: Update ratio sync message text
- Change message from "已与上游倍率完全一致,无需同步" to "未找到差异化倍率,无需同步"
- Update English translation to "No differential ratio found, no synchronization is required"
- Improve user experience clarity for upstream ratio synchronization status
2025-06-21 02:09:08 +08:00
Apple\Apple
0b326e7af4 🌐 feat(i18n): add English translation for "暴露倍率接口"
- Add translation "Expose ratio API" for Chinese text "暴露倍率接口"
- Update English locale file (en.json) with new translation entry
2025-06-21 02:00:58 +08:00
CaIon
1a1ff836b5 Merge branch 'alpha' 2025-06-21 01:26:57 +08:00
Calcium-Ion
34fed74f64 Merge pull request #1252 from feitianbubu/pr/fix-playgroud-user-setting
fix: playground write user context to check acceptUnsetRatio
2025-06-21 01:25:22 +08:00
Calcium-Ion
f89b29928c Merge pull request #1264 from feitianbubu/pr/uniq-channel-models
fix: unique channel models
2025-06-21 01:21:28 +08:00
Calcium-Ion
2c6d4460c3 Merge pull request #1272 from QuantumNous/gemini-stream-completion-count-fix
fix: gemini 原生格式流模式中断请求未计费
2025-06-21 01:21:01 +08:00
CaIon
7afd3f97ee fix: remove unnecessary error handling in token counting functions 2025-06-21 01:16:54 +08:00
CaIon
0708452939 fix: improve usage calculation in GeminiTextGenerationStreamHandler 2025-06-21 01:08:15 +08:00
CaIon
a9e5d99ea3 refactor: token counter logic 2025-06-21 00:54:40 +08:00
creamlike1024
a56d9ea98b fix: gemini 原生格式流模式中断请求未计费 2025-06-20 23:01:10 +08:00
CaIon
f5e80af0b3 fix: update response handling in GeminiTextGenerationStreamHandler
- Changed response handling from ObjectData to StringData for improved data processing.
- Ensured proper error logging in case of response handling failure.
2025-06-20 21:55:28 +08:00
CaIon
a1a7ddbc83 fix: update payment method handling in topup controller
- Refactored payment method validation to check against available methods.
- Changed payment method types from "zfb" to "alipay" and "wx" to "wxpay" for consistency.
- Updated the purchase request to use the validated payment method directly.
2025-06-20 17:48:55 +08:00
Calcium-Ion
8b209d8926 Merge pull request #1271 from RedwindA/feat/vertex-budgetControl
feat: vertex budget control in model name
2025-06-20 17:24:42 +08:00
RedwindA
9344cab59a fix: update model name logic for vertex 2025-06-20 16:40:51 +08:00
Calcium-Ion
03468e05e4 Merge pull request #1267 from t0ng7u/feature/upstream-ratio-sync
🔄 feat(ratio-sync): introduce upstream ratio synchronisation feature #1220
2025-06-20 16:22:00 +08:00
Calcium-Ion
11792ba1a4 Merge pull request #1270 from QuantumNous/refactor_model_mapping
feat: implement new handlers for relay processing
2025-06-20 16:12:41 +08:00
Calcium-Ion
5baaa06896 Merge pull request #1244 from feitianbubu/feat/video
feat: 支持可灵视频渠道(异步任务)
2025-06-20 16:11:59 +08:00
CaIon
d3286893c4 feat: implement new handlers for audio, image, embedding, and responses processing
- Added new handlers: AudioHelper, ImageHelper, EmbeddingHelper, and ResponsesHelper to manage respective requests.
- Updated ModelMappedHelper to accept request parameters for better model mapping.
- Enhanced error handling and validation across new handlers to ensure robust request processing.
- Introduced support for new relay formats in relay_info and updated relevant functions accordingly.
2025-06-20 16:02:23 +08:00
CaIon
b087b20bac refactor: update error handling in ClaudeHelper and GeminiHelper 2025-06-20 14:53:27 +08:00
Calcium-Ion
6a5a839d4d Merge pull request #1268 from QuantumNous/alpha
fix: gemini relay empty response
2025-06-20 02:31:11 +08:00
CaIon
5d8a0952b4 feat: enhance error handling in GeminiHelper and streamline response processing
- Added status code mapping handling in GeminiHelper to reset status codes based on response.
- Removed redundant candidate check in GeminiTextGenerationHandler to simplify response processing.
2025-06-20 01:42:19 +08:00
CaIon
bd08ecc1e0 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-20 01:07:52 +08:00
CaIon
e4f61c1084 Merge branch 'main' into alpha 2025-06-20 01:07:44 +08:00
Apple\Apple
a38215478f 💄 style(TopUp): Optimize payment method buttons layout based on quantity
Enhance the UI of payment method selection area with responsive layouts:
- Use 2-column grid when exactly 2 payment methods are present
- Use 3-column grid for 3 payment methods
- Use compact card layout for more than 3 payment methods
- Full-width button for single payment method

This improves the visual balance across different device sizes and payment provider configurations, ensuring buttons fill their grid cells appropriately with the w-full class.
2025-06-20 00:52:45 +08:00
RedwindA
c192d07a04 fix gizmo completion ratio 2025-06-19 20:16:04 +08:00
RedwindA
098880b796 Merge remote-tracking branch 'upstream/alpha' into update-gemini-ratio 2025-06-19 20:02:27 +08:00
Apple\Apple
150c506ece 🚀 chore(controller, dto): elevate ratio-sync feature to production readiness
WHAT’S NEW
• controller/ratio_sync.go
  – Deleted unused local structs (TestResult, DifferenceItem, SyncableChannel).
  – Centralised config with constants: defaultTimeoutSeconds, defaultEndpoint, maxConcurrentFetches, ratioTypes.
  – Replaced magic numbers; added semaphore-based concurrency limit and shared http.Client (with TLS & Expect-Continue timeouts).
  – Added comprehensive error handling and context-aware logging via common.Log* helpers.
  – Checked DB errors from GetChannelsByIds; early-return on failures or empty upstream list.
  – Removed custom-channel support; logic now relies solely on ChannelIDs.
  – Minor clean-ups: import grouping, string trimming, endpoint normalisation.

• dto/ratio_sync.go
  – Simplified UpstreamRequest: dropped unused CustomChannels field.

WHY
These improvements harden the ratio-sync endpoint for production use by preventing silent failures, controlling resource usage, and making behaviour configurable and observable.

HOW
No business logic change—only structural refactor, logging, and safeguards—so existing API contracts (aside from removed custom_channels) remain intact.
2025-06-19 19:55:51 +08:00
CaIon
f978d8224e feat: add data presence check before batch update in utils 2025-06-19 19:34:57 +08:00
CaIon
ab59887933 refactor: streamline JSON response structure in channel API endpoints 2025-06-19 19:03:35 +08:00
Apple\Apple
458472f3e2 🔍 feat(ratio-sync): add fuzzy model search & enhance empty-state UX
Summary
1. Add model name search box
   • Introduce Semi UI `Input` with `IconSearch` prefix next to the “Apply Sync” button.
   • Support case-insensitive fuzzy matching of model names.
   • Real-time filtering, pagination and bulk-select logic now work on filtered data.

2. Improve empty state handling
   • Add `hasSynced` flag to distinguish “not synced yet” from “synced with no differences”.
   • Display messages:
     – “Please select sync channels” when no sync has been performed.
     – “No differences found” when a sync completed with zero discrepancies.
     – “No matching model found” when search yields no results.

3. UI tweaks
   • Replace lucide-react `Search` icon with Semi UI `IconSearch` for visual consistency.
   • Keep responsive width and clearable input for better usability.

Why
These changes allow admins to quickly locate specific models and provide accurate feedback on the sync status, greatly improving the usability of the Upstream Ratio Sync page.
2025-06-19 18:54:46 +08:00
Apple\Apple
a9f98c5d39 🛠️ chore(ratio-sync): improve upstream ratio comparison & output cleanliness
Summary
1. Consider “both unset” as identical
   • When both localValue and upstreamValue are nil, mark upstreamValue as "same" to avoid showing “Not set”.

2. Exclude fully-synced upstream channels from result
   • Scan `differences` to detect channels that contain at least one divergent value.
   • Remove channels whose every ratio is either `"same"` or `nil`, so the frontend only receives actionable discrepancies.

Why
These changes reduce visual noise in the Upstream Ratio Sync table, making it easier for admins to focus on models requiring attention. No functional regressions or breaking API changes are introduced.
2025-06-19 18:38:43 +08:00
creamlike1024
2b7dff2d94 fix: 使用日志分组查询 2025-06-19 17:17:32 +08:00
CaIon
58752d2dcf Merge branch 'alpha' 2025-06-19 16:17:56 +08:00
Apple\Apple
67546f4b2a chore(ui): enhance channel selector with status avatars and UI improvements
Add visual status indicators and improve user experience for the upstream ratio sync channel selector modal.

Features:
- Add status-based avatar indicators for channels (enabled/disabled/auto-disabled)
- Implement search functionality with text highlighting
- Add endpoint configuration input for each channel
- Optimize component structure with reusable ChannelInfo component

UI Improvements:
- Custom styling for transfer component items
- Hide scrollbars for cleaner appearance in transfer lists
- Responsive layout adjustments for channel information display
- Color-coded avatars: green (enabled), red (disabled), amber (auto-disabled), grey (unknown)

Code Quality:
- Extract channel status configuration to constants
- Create reusable ChannelInfo component to reduce code duplication
- Implement proper search filtering for both channel names and URLs
- Add consistent styling classes for transfer demo components

Files modified:
- web/src/components/settings/ChannelSelectorModal.js
- web/src/pages/Setting/Ratio/UpstreamRatioSync.js
- web/src/index.css

This enhancement provides better visual feedback for channel status and improves the overall user experience when selecting channels for ratio synchronization.
2025-06-19 16:05:50 +08:00
CaIon
8e9dae7b5f fix: ratio render 2025-06-19 15:36:06 +08:00
Apple\Apple
fb4ff63bad 🗑️ chore(custom channel): Remove custom channel support from upstream ratio sync
Remove all custom channel functionality from the upstream ratio sync feature to simplify the codebase and focus on database-stored channels only.

Changes:
- Remove custom channel UI components and related state management
- Remove custom channel testing and validation logic
- Simplify ChannelSelectorModal by removing custom channel input fields
- Update API payload to only include channel_ids, removing custom_channels
- Remove custom channel processing logic from backend controller
- Update import path for DEFAULT_ENDPOINT constant

Files modified:
- web/src/pages/Setting/Ratio/UpstreamRatioSync.js
- web/src/components/settings/ChannelSelectorModal.js
- controller/ratio_sync.go

This change streamlines the ratio synchronization workflow by focusing solely on pre-configured database channels, reducing complexity and potential maintenance overhead.
2025-06-19 15:17:05 +08:00
skynono
1fed1ee567 fix: unique channel models 2025-06-19 15:16:26 +08:00
creamlike1024
02571c20ff Merge branch 'xqx121-main' 2025-06-19 14:51:15 +08:00
creamlike1024
8201daa4b4 update relay-gemini-native.go 2025-06-19 14:50:50 +08:00
creamlike1024
5b54624cd5 Merge branch 'main' of github.com:xqx121/new-api into xqx121-main 2025-06-19 14:45:41 +08:00
CaIon
db737567fb Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-19 14:36:55 +08:00
CaIon
5c3898d13e Merge branch 'main' into alpha 2025-06-19 14:36:17 +08:00
Calcium-Ion
2fdb2be6d0 Merge pull request #1253 from TopickAI/main
Fix Vertex channel global region format for claude models
2025-06-19 14:35:18 +08:00
Calcium-Ion
ab78efc815 Merge pull request #1260 from tbphp/fix-gemini-empty-content-error
fix: Gemini & Vertex empty content error
2025-06-19 14:34:27 +08:00
Calcium-Ion
fcf97d1796 Merge pull request #1261 from KamiPasi/new-api-pr
透传thinking参数, 豆包模型用来控制是否思考
2025-06-19 14:33:41 +08:00
Calcium-Ion
e85cc6acbe Merge pull request #1262 from wans10/main
fix: 修复渠道界面模型选择下拉框模型重复显示
2025-06-19 14:32:49 +08:00
Calcium-Ion
da002e6ca9 Merge pull request #1257 from QuantumNous/pay_custom
feat: 自定义充值方式
2025-06-19 14:31:45 +08:00
wans10
070e7b6911 修复渠道界面模型选择下拉框模型重复显示 2025-06-19 13:34:11 +08:00
KamiPasi
d5a3eb7d04 透传thinking参数, 豆包模型用来控制是否思考 2025-06-19 12:06:42 +08:00
skynono
616e6953cc feat: add unsupported test case for kling channel 2025-06-19 11:53:47 +08:00
skynono
b7c77777a5 feat: add video channel kling fix 2025-06-19 11:53:47 +08:00
skynono
8a79de333a feat: add video channel kling 2025-06-19 11:53:42 +08:00
tbphp
a87d4271d3 fix: Gemini & Vertex empty content error 2025-06-19 11:25:59 +08:00
Apple\Apple
7975cdf3bf 🚀 feat(ratio-sync): major refactor & UX overhaul for Upstream Ratio Sync 2025-06-19 08:57:34 +08:00
Calcium-Ion
b2badad554 Merge pull request #1258 from feitianbubu/pr/fix-task-cost-time
fix: task cost time
2025-06-18 23:15:28 +08:00
skynono
133d8c9f77 fix: task cost time 2025-06-18 21:57:46 +08:00
creamlike1024
9708d645d3 feat: 充值方式设置 2025-06-18 21:23:06 +08:00
CaIon
0bca4d3efc Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-18 20:51:06 +08:00
CaIon
7572e791f6 feat(relay): add debug logging for Gemini request body and introduce flexible speech configuration 2025-06-18 20:50:13 +08:00
Apple\Apple
a180d13182 🚚 Refactor(ratio_setting): refactor ratio management into standalone ratio_setting package
Summary
• Migrated all ratio-related sources into `setting/ratio_setting/`
  – `model_ratio.go` (renamed from model-ratio.go)
  – `cache_ratio.go`
  – `group_ratio.go`
• Changed package name to `ratio_setting` and relocated initialization (`ratio_setting.InitRatioSettings()` in main).
• Updated every import & call site:
  – Model / cache / completion / image ratio helpers
  – Group ratio helpers (`GetGroupRatio*`, `ContainsGroupRatio`, `CheckGroupRatio`, etc.)
  – JSON-serialization & update helpers (`*Ratio2JSONString`, `Update*RatioByJSONString`)
• Adjusted controllers, middleware, relay helpers, services and models to reference the new package.
• Removed obsolete `setting` / `operation_setting` imports; added missing `ratio_setting` imports.
• Adopted idiomatic map iteration (`for key := range m`) where value is unused.
• Ran static checks to ensure clean build.

This commit centralises all ratio configuration (model, cache and group) in one cohesive module, simplifying future maintenance and improving code clarity.
2025-06-18 18:00:49 +08:00
xqx121
edcdb378fd Update relay-gemini-native.go 2025-06-18 14:26:23 +08:00
sgyy
4447e51588 fix: Vertex channel global region format 2025-06-18 11:21:56 +08:00
skynono
fb8aac650f fix: playground write user context to check acceptUnsetRatio 2025-06-18 11:06:15 +08:00
Apple\Apple
ba6b0637cc 🐛 fix(detail): explicitly set preventScroll={true} on Tabs to stop page jump
Problem
Semi UI’s Tabs calls `focus()` on the active tab during mount, causing the browser to scroll the page to that element.
Using the bare `preventScroll` shorthand was not picked up reliably, so the page still jumped to the Tabs’ position on first render.

Changes
• Updated both Tabs instances in `web/src/pages/Detail/index.js` to `preventScroll={true}` instead of the shorthand prop.
• Ensures the prop is explicitly interpreted as boolean `true`, converting the internal call to `focus({ preventScroll: true })`.

Result
The `Detail` page now stays at its original scroll position after load, eliminating the unexpected auto-scroll behavior.
2025-06-18 05:10:32 +08:00
Apple\Apple
3502730dfc 🛠️ fix(detail): disable automatic page scroll caused by Tabs focus
The initial render of the `Detail` page was jumping to the first `Tabs` component because Semi UI calls `focus()` on the active tab, which triggers the browser’s default scroll-into-view behavior.

Changes made
• Added `preventScroll` to the chart-selector `Tabs` (type="button").
• Added `preventScroll` to the uptime-monitor `Tabs` (type="card").

These flags convert the internal `focus()` call to `focus({ preventScroll: true })`, allowing the page to stay at its current position after load.

No functional logic is changed other than disabling the unwanted scroll; UI and user interactions remain the same.
2025-06-18 04:36:12 +08:00
RedwindA
b95c5bb8f4 feat(gemini): add pricing for Gemini 2.5 Flash Lite preview audio input 2025-06-18 03:38:58 +08:00
RedwindA
f35784aa97 feat(gemini): update audio input pricing and adjust model handling logic 2025-06-18 03:25:59 +08:00
Apple\Apple
3746482e8c 🏷️ chore(ui): Hide Type Tabs in Tag-Aggregation Mode & Refine Query Logic
frontend(ChannelsTable):
• Do not render type-filter Tabs when `enableTagMode` is true, preventing UI/logic conflicts in tag aggregation view.
• Adjust API query construction:
  – Append `type=` param only when NOT in tag mode and selected tab ≠ 'all'.
  – Applies to both `loadChannels` and `searchChannels`.
• Result: UI stays clean in tag view, and backend receives correct parameters across modes.

No other functionality affected.
2025-06-18 02:59:34 +08:00
Apple\Apple
5ed4b60b8f 🎨 style(EditChannel): replace fixed-height TextArea with autosize to eliminate unwanted scrollbar
The “Batch Create” secret-key input in Channel Edit previously used a TextArea
with a hard-coded `minHeight`, which caused an extra scrollbar and blank space
on the right side of the field.
This change:

• Removes the fixed `minHeight` in favour of `autosize={{ minRows: 6, maxRows: 6 }}`
• Keeps the field’s rounded appearance while letting it grow/shrink with
  content, improving usability on both desktop and mobile

No other components or global styles are affected.
2025-06-18 02:41:06 +08:00
Apple\Apple
547da2da60 🚀 feat(Channels): Enhance Channel Filtering & Performance
feat(api):
• Add optional `type` query param to `/api/channel` endpoint for type-specific pagination
• Return `type_counts` map with counts for each channel type
• Implement `GetChannelsByType`, `CountChannelsByType`, `CountChannelsGroupByType` in `model/channel.go`

feat(frontend):
• Introduce type Tabs in `ChannelsTable` to switch between channel types
• Tabs show dynamic counts using backend `type_counts`; “All” is computed from sum
• Persist active type, reload data on tab change (with proper query params)

perf(frontend):
• Use a request counter (`useRef`) to discard stale responses when tabs switch quickly
• Move all `useMemo` hooks to top level to satisfy React Hook rules
• Remove redundant local type counting fallback when backend data present

ui:
• Remove icons from response-time tags for cleaner look
• Use Semi-UI native arrow controls for Tabs; custom arrow code deleted

chore:
• Minor refactor & comments for clarity
• Ensure ESLint Hook rules pass

Result: Channel list now supports fast, accurate type filtering with correct counts, improved concurrency safety, and cleaner UI.
2025-06-18 02:33:18 +08:00
Apple\Apple
f88ed4dd5c Merge remote-tracking branch 'origin/main' into alpha 2025-06-18 01:30:12 +08:00
Apple\Apple
87fc681df3 🚀 feat(ui): isolate ratio configurations into dedicated “Ratio” tab and refactor settings components
Summary
• Added new Ratio tab in Settings for managing all ratio-related configurations (group & model multipliers).
• Created `RatioSetting` component to host GroupRatio, ModelRatio, Visual Editor and Unset-Models panels.
• Moved ratio components to `web/src/pages/Setting/Ratio/` directory:
  – `GroupRatioSettings.js`
  – `ModelRatioSettings.js`
  – `ModelSettingsVisualEditor.js`
  – `ModelRationNotSetEditor.js`
• Updated imports in `RatioSetting.js` to use the new path.
• Updated main Settings router (`web/src/pages/Setting/index.js`) to include the new “Ratio Settings” tab.
• Pruned `OperationSetting.js`:
  – Removed ratio-specific cards, tabs and unused imports.
  – Reduced state to only the keys required by its child components.
  – Deleted obsolete fields (`StreamCacheQueueLength`, `CheckSensitiveOnCompletionEnabled`, `StopOnSensitiveEnabled`).
• Added boolean handling simplification in `OperationSetting.js`.
• Adjusted helper import list and removed unused translation hook.

Why
Separating ratio-related settings improves UX clarity, reduces cognitive load in the Operation Settings panel and keeps the codebase modular and easier to maintain.

BREAKING CHANGE
The file paths for ratio components have changed. Any external imports referencing the old `Operation` directory must update to the new `Ratio` path.
2025-06-18 01:29:35 +08:00
RedwindA
a39b2f5aa7 feat(ratio): add new Gemini model ratios and enhance flash model handling 2025-06-18 01:09:09 +08:00
Calcium-Ion
0b9b21eafd Merge pull request #1247 from RedwindA/feat/25lite-thinking
feat: improve gemini thinking budget adaption
2025-06-18 01:00:08 +08:00
RedwindA
21f43b0dd8 feat(Gemini): enhance budget clamping logic for Gemini models 2025-06-18 00:49:35 +08:00
CaIon
3a7ba5725c fix(relay): ensure consistent setting of web search context size in TextHelper function 2025-06-18 00:37:22 +08:00
CaIon
2e4fa32d63 fix(relay): refine error message for unsupported MIME types and enhance error handling in OpenAI wrapper 2025-06-17 22:44:57 +08:00
CaIon
0199896d9a fix(relay): improve error handling for unsupported MIME types by sanitizing URLs 2025-06-17 22:40:41 +08:00
CaIon
edd9049100 feat(file_decoder): expand MIME type detection to include additional file extensions 2025-06-17 22:20:19 +08:00
CaIon
290c763901 feat(file_decoder): add debug logging for MIME type detection when handling application/octet-stream 2025-06-17 22:18:51 +08:00
CaIon
226446a3b5 feat(file_decoder): enhance MIME type detection based on URL and Content-Disposition header 2025-06-17 21:49:13 +08:00
Calcium-Ion
ab627db4be Merge pull request #1239 from QuantumNous/auto_group
feat: auto分组
2025-06-17 21:14:09 +08:00
CaIon
0f35d2368f feat: enhance group ratio handling in pricing calculations 2025-06-17 21:05:35 +08:00
CaIon
3c276d13c4 feat(GroupRatioSettings): enhance JSON validation for group ratios 2025-06-17 21:05:24 +08:00
CaIon
b7c3328d43 feat(channel): enhance Claude response handling with new Done flag and improved usage tracking 2025-06-17 20:08:25 +08:00
CaIon
4d8e63bd1a feat(channel): add handling for pre_consume_token_quota_failed error type 2025-06-17 16:46:52 +08:00
CaIon
51757b83e1 Merge branch 'alpha' 2025-06-17 14:49:13 +08:00
Calcium-Ion
87c260093a Merge pull request #1243 from cjm0810151/main
fix(audio): :bugs: fix webm audio strconv.ParseFloat: parsing "N/A"
2025-06-17 14:48:17 +08:00
Calcium-Ion
691a878aa2 Merge pull request #1240 from RedwindA/fix/redis
Fix: optimize Redis expiration handling and refactor cache duration retrieval
2025-06-17 14:47:00 +08:00
chenjm
b33d808bc1 fix(audio): :bugs: fix webm audio strconv.ParseFloat: parsing "N/A" 2025-06-17 10:04:36 +08:00
chenjm
4559f5b2d3 fix(audio): :bugs: fix webm audio strconv.ParseFloat: parsing "N/A" 2025-06-17 09:21:56 +08:00
RedwindA
0b9c6ecb00 🔧 refactor(redis): replace direct constant usage with RedisKeyCacheSeconds function for cache duration 2025-06-17 03:24:39 +08:00
RedwindA
a7d87475af 🔧 fix(redis): only set expiration if greater than 0 in RedisHSetObj 2025-06-17 02:37:19 +08:00
CaIon
ba37750943 Merge remote-tracking branch 'origin/alpha' into alpha 2025-06-17 00:09:38 +08:00
CaIon
4fc85d27e9 🧹 chore(relay): remove unused import in relay-palm.go 2025-06-17 00:09:26 +08:00
Calcium-Ion
246ca40aac Merge pull request #1231 from RedwindA/feat/gemini-budget-in-name
feat(Gemini): implement thinking budget control in model name
2025-06-17 00:03:53 +08:00
Calcium-Ion
59a6fa7274 Merge pull request #1208 from feitianbubu/pr/fix-hot-reload-sync-options
fix: enabled hot reload SyncOptions
2025-06-16 23:03:29 +08:00
creamlike1024
7fa21ce95f feat: auto分组 2025-06-16 22:15:12 +08:00
CaIon
6b7295bbdf 🔧 refactor(relay): replace UUID generation with helper function for response IDs 2025-06-16 21:02:27 +08:00
Apple\Apple
b4b6bd46fe Merge branch 'main' into alpha 2025-06-16 20:06:40 +08:00
Apple\Apple
d5c96cb036 🐛 fix(console-setting): ensure announcements are returned in newest-first order
Summary
• Added stable, descending sort to `GetAnnouncements()` so that the API always returns the latest announcements first.
• Introduced helper `getPublishTime()` to safely parse `publishDate` (RFC 3339) and fall back to zero value on failure.
• Switched to `sort.SliceStable` for deterministic ordering when timestamps are identical.
• Imported the standard `sort` package and removed redundant, duplicate date parsing.

Impact
Front-end no longer needs to perform client-side sorting; the latest announcement is guaranteed to appear at the top on all platforms and clients.
2025-06-16 20:05:54 +08:00
RedwindA
1294d286ee refactor: replace inline closure with a helper function 2025-06-16 19:41:42 +08:00
Calcium-Ion
dc95d0d1e6 Merge pull request #1205 from a37836323/fix-azure-responses-api
修复Azure渠道对responses API的兼容性支持 - 为Azure渠道添加对responses API的特殊处理 - 兼容微软新…
2025-06-16 19:17:21 +08:00
RedwindA
c48a398737 update i18n 2025-06-15 23:40:58 +08:00
RedwindA
e735377218 feat: implement thinking budget control in model name 2025-06-15 23:20:41 +08:00
RedwindA
1ec2bbd533 update input range and description for thinking adapter budget tokens 2025-06-12 18:20:58 +08:00
skynono
21edb75081 fix: enabled hot reload SyncOptions 2025-06-12 17:17:07 +08:00
a37836323
856465ae59 修复Azure渠道对responses API的兼容性支持 - 为Azure渠道添加对responses API的特殊处理 - 兼容微软新的API格式,使用preview版本的api-version - 修复了Azure渠道无法正确处理responses请求的问题 2025-06-11 22:11:47 +08:00
skynono
4a313a5f93 feat: add moonshot(kimi) update balance 2025-06-05 18:16:37 +08:00
Adam.Wang
6be78ff283 feat: 火山引擎增加文生图 2025-05-23 16:42:53 +08:00
Adam.Wang
5281f2ba64 feat: 火山引擎增加文生图 2025-05-22 13:58:05 +08:00
Peter Dave Hello
bc322ddac4 refactor: optimize Dockerfile apk usage 2025-04-29 22:54:43 +08:00
127 changed files with 5034 additions and 1173 deletions

View File

@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
FROM alpine
RUN apk update \
&& apk upgrade \
RUN apk upgrade --no-cache \
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates

View File

@@ -241,6 +241,7 @@ const (
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
}

View File

@@ -141,7 +141,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
txn.Expire(ctx, key, expiration)
// 只有在 expiration 大于 0 时才设置过期时间
if expiration > 0 {
txn.Expire(ctx, key, expiration)
}
_, err := txn.Exec(ctx)
if err != nil {

View File

@@ -13,6 +13,7 @@ import (
"math/big"
"math/rand"
"net"
"net/url"
"os"
"os/exec"
"runtime"
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
}
// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
durationStr := string(bytes.TrimSpace(output))
if durationStr == "N/A" {
// Create a temporary output file name
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.Wrap(err, "failed to create temporary file")
}
tmpName := tmpFp.Name()
// Close immediately so ffmpeg can open the file on Windows.
_ = tmpFp.Close()
defer os.Remove(tmpName)
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
if err := ffmpegCmd.Run(); err != nil {
return 0, errors.Wrap(err, "failed to run ffmpeg")
}
// Recalculate the duration of the new file
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
}
durationStr = string(bytes.TrimSpace(output))
}
return strconv.ParseFloat(durationStr, 64)
}
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)
if err != nil {
return base + endpoint
}
end := endpoint
if end == "" {
end = "/"
}
ref, err := url.Parse(end)
if err != nil {
return base + endpoint
}
return u.ResolveReference(ref).String()
}

View File

@@ -2,12 +2,10 @@ package constant
import "one-api/common"
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
// 使用函数来避免初始化顺序带来的赋值问题
func RedisKeyCacheSeconds() int {
return common.SyncFrequency
}
// Cache keys
const (

View File

@@ -5,6 +5,7 @@ type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
TaskPlatformKling TaskPlatform = "kling"
)
const (

View File

@@ -4,11 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/shopspring/decimal"
"io"
"net/http"
"one-api/common"
"one-api/model"
"one-api/service"
"one-api/setting"
"strconv"
"time"
@@ -304,6 +306,40 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
return balance, nil
}
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
url := "https://api.moonshot.cn/v1/users/me/balance"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
type MoonshotBalanceData struct {
AvailableBalance float64 `json:"available_balance"`
VoucherBalance float64 `json:"voucher_balance"`
CashBalance float64 `json:"cash_balance"`
}
type MoonshotBalanceResponse struct {
Code int `json:"code"`
Data MoonshotBalanceData `json:"data"`
Scode string `json:"scode"`
Status bool `json:"status"`
}
response := MoonshotBalanceResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if !response.Status || response.Code != 0 {
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
}
availableBalanceCny := response.Data.AvailableBalance
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
channel.UpdateBalance(availableBalanceUsd)
return availableBalanceUsd, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
@@ -332,6 +368,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelDeepSeekBalance(channel)
case common.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
case common.ChannelTypeMoonshot:
return updateChannelMoonshotBalance(channel)
default:
return 0, errors.New("尚未实现")
}

View File

@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info)
err = helper.ModelMappedHelper(c, info, nil)
if err != nil {
return err, nil
}
@@ -165,8 +168,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -312,7 +315,7 @@ func testAllChannels(notify bool) error {
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}

View File

@@ -52,6 +52,14 @@ func GetAllChannels(c *gin.Context) {
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
// type filter
typeStr := c.Query("type")
typeFilter := -1
if typeStr != "" {
if t, err := strconv.Atoi(typeStr); err == nil {
typeFilter = t
}
}
var total int64
@@ -72,6 +80,14 @@ func GetAllChannels(c *gin.Context) {
}
// 计算 tag 总数用于分页
total, _ = model.CountAllTags()
} else if typeFilter >= 0 {
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
if err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
channelData = channels
total, _ = model.CountChannelsByType(typeFilter)
} else {
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
if err != nil {
@@ -82,14 +98,18 @@ func GetAllChannels(c *gin.Context) {
total, _ = model.CountAllChannels()
}
// calculate type counts
typeCounts, _ := model.CountChannelsGroupByType()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": gin.H{
"items": channelData,
"total": total,
"page": p,
"page_size": pageSize,
"items": channelData,
"total": total,
"page": p,
"page_size": pageSize,
"type_counts": typeCounts,
},
})
return
@@ -217,10 +237,20 @@ func SearchChannels(c *gin.Context) {
}
channelData = channels
}
// calculate type counts for search results
typeCounts := make(map[int64]int64)
for _, channel := range channelData {
typeCounts[int64(channel.Type)]++
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": channelData,
"data": gin.H{
"items": channelData,
"type_counts": typeCounts,
},
})
return
}

View File

@@ -1,15 +1,17 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range setting.GetGroupRatioCopy() {
for groupName := range ratio_setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false)
for groupName, ratio := range setting.GetGroupRatioCopy() {
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
if setting.GroupInUserUsableGroups("auto") {
usableGroups["auto"] = map[string]interface{}{
"ratio": "自动",
"desc": setting.GetUsableGroupDescription("auto"),
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

View File

@@ -9,9 +9,9 @@ import (
"one-api/middleware"
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"one-api/setting/console_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -41,46 +41,48 @@ func GetStatus(c *gin.Context) {
cs := console_setting.GetConsoleSetting()
data := gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
"pay_methods": setting.PayMethods,
// 面板启用开关
"api_info_enabled": cs.ApiInfoEnabled,
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
"announcements_enabled": cs.AnnouncementsEnabled,
"faq_enabled": cs.FAQEnabled,
"api_info_enabled": cs.ApiInfoEnabled,
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
"announcements_enabled": cs.AnnouncementsEnabled,
"faq_enabled": cs.FAQEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,

View File

@@ -2,7 +2,7 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,6 +15,9 @@ import (
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/setting"
"github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -134,6 +137,9 @@ func init() {
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
return m.Id
})
}
func ListModels(c *gin.Context) {
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
if tokenGroup != "" {
group = tokenGroup
}
models := model.GetGroupModels(group)
var models []string
if tokenGroup == "auto" {
for _, autoGroup := range setting.AutoGroups {
groupModels := model.GetGroupModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
models = append(models, g)
}
}
}
} else {
models = model.GetGroupModels(group)
}
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])

View File

@@ -7,6 +7,7 @@ import (
"one-api/model"
"one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"strings"
@@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
err = setting.CheckGroupRatio(option.Value)
err = ratio_setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -3,7 +3,6 @@ package controller
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"one-api/setting"
"time"
"github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
// Write user context to ensure acceptUnsetRatio is available
userId := c.GetInt("id")
userCache, err := model.GetUserCache(userId)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
return
}
userCache.WriteContext(c)
Relay(c)
}

View File

@@ -3,7 +3,7 @@ package controller
import (
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
@@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
for s, f := range setting.GetGroupRatioCopy() {
for s, f := range ratio_setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
@@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) {
if err == nil {
group = user.Group
for g := range groupRatio {
ratio, ok := setting.GetGroupGroupRatio(group, g)
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
if ok {
groupRatio[g] = ratio
}
@@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) {
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range setting.GetGroupRatioCopy() {
for group := range ratio_setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
@@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) {
}
func ResetModelRatio(c *gin.Context) {
defaultStr := operation_setting.DefaultModelRatio2JSONString()
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
@@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,

View File

@@ -0,0 +1,24 @@
package controller
import (
"net/http"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetRatioConfig(c *gin.Context) {
if !ratio_setting.IsExposeRatioEnabled() {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "倍率配置接口未启用",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ratio_setting.GetExposedData(),
})
}

322
controller/ratio_sync.go Normal file
View File

@@ -0,0 +1,322 @@
package controller
import (
"context"
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
const (
defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config"
maxConcurrentFetches = 8
)
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct {
Name string `json:"name"`
Data map[string]any `json:"data,omitempty"`
Err string `json:"err,omitempty"`
}
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
if req.Timeout <= 0 {
req.Timeout = defaultTimeoutSeconds
}
var upstreams []dto.UpstreamDTO
if len(req.ChannelIDs) > 0 {
intIds := make([]int, 0, len(req.ChannelIDs))
for _, id64 := range req.ChannelIDs {
intIds = append(intIds, int(id64))
}
dbChannels, err := model.GetChannelsByIds(intIds)
if err != nil {
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
return
}
for _, ch := range dbChannels {
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
upstreams = append(upstreams, dto.UpstreamDTO{
Name: ch.Name,
BaseURL: strings.TrimRight(base, "/"),
Endpoint: "",
})
}
}
}
if len(upstreams) == 0 {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
return
}
var wg sync.WaitGroup
ch := make(chan upstreamResult, len(upstreams))
sem := make(chan struct{}, maxConcurrentFetches)
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
for _, chn := range upstreams {
wg.Add(1)
go func(chItem dto.UpstreamDTO) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
endpoint := chItem.Endpoint
if endpoint == "" {
endpoint = defaultEndpoint
} else if !strings.HasPrefix(endpoint, "/") {
endpoint = "/" + endpoint
}
fullURL := chItem.BaseURL + endpoint
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
resp, err := client.Do(httpReq)
if err != nil {
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
return
}
var body struct {
Success bool `json:"success"`
Data map[string]any `json:"data"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
if !body.Success {
ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
return
}
ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
}(chn)
}
wg.Wait()
close(ch)
localData := ratio_setting.GetExposedData()
var testResults []dto.TestResult
var successfulChannels []struct {
name string
data map[string]any
}
for r := range ch {
if r.Err != "" {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "error",
Error: r.Err,
})
} else {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "success",
})
successfulChannels = append(successfulChannels, struct {
name string
data map[string]any
}{name: r.Name, data: r.Data})
}
}
differences := buildDifferences(localData, successfulChannels)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"differences": differences,
"test_results": testResults,
},
})
}
func buildDifferences(localData map[string]any, successfulChannels []struct {
name string
data map[string]any
}) map[string]map[string]dto.DifferenceItem {
differences := make(map[string]map[string]dto.DifferenceItem)
allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for modelName := range allModels {
for _, ratioType := range ratioTypes {
var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists {
localValue = val
}
}
}
upstreamValues := make(map[string]interface{})
hasUpstreamValue := false
hasDifference := false
for _, channel := range successfulChannels {
var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val
hasUpstreamValue = true
if localValue != nil && localValue != val {
hasDifference = true
} else if localValue == val {
upstreamValue = "same"
}
}
}
if upstreamValue == nil && localValue == nil {
upstreamValue = "same"
}
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
hasDifference = true
}
upstreamValues[channel.name] = upstreamValue
}
shouldInclude := false
if localValue != nil {
if hasDifference {
shouldInclude = true
}
} else {
if hasUpstreamValue {
shouldInclude = true
}
}
if shouldInclude {
if differences[modelName] == nil {
differences[modelName] = make(map[string]dto.DifferenceItem)
}
differences[modelName][ratioType] = dto.DifferenceItem{
Current: localValue,
Upstreams: upstreamValues,
}
}
}
}
channelHasDiff := make(map[string]bool)
for _, ratioMap := range differences {
for _, item := range ratioMap {
for chName, val := range item.Upstreams {
if val != nil && val != "same" {
channelHasDiff[chName] = true
}
}
}
}
for modelName, ratioMap := range differences {
for ratioType, item := range ratioMap {
for chName := range item.Upstreams {
if !channelHasDiff[chName] {
delete(item.Upstreams, chName)
}
}
differences[modelName][ratioType] = item
}
}
return differences
}
func GetSyncableChannels(c *gin.Context) {
channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var syncableChannels []dto.SyncableChannel
for _, channel := range channels {
if channel.GetBaseURL() != "" {
syncableChannels = append(syncableChannels, dto.SyncableChannel{
ID: channel.Id,
Name: channel.Name,
BaseURL: channel.GetBaseURL(),
Status: channel.Status,
})
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": syncableChannels,
})
}

View File

@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)

View File

@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
case constant.TaskPlatformKling:
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
default:
common.SysLog("未知平台")
}

140
controller/task_video.go Normal file
View File

@@ -0,0 +1,140 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
)
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
})
if err != nil {
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
}
var responseItem map[string]interface{}
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
}
code, _ := responseItem["code"].(float64)
if code != 0 {
return fmt.Errorf("video task fetch failed for task %s", taskId)
}
data, ok := responseItem["data"].(map[string]interface{})
if !ok {
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
return fmt.Errorf("video task data format error for task %s", taskId)
}
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
if status, ok := data["task_status"].(string); ok {
switch status {
case "submitted", "queued":
task.Status = model.TaskStatusSubmitted
case "processing":
task.Status = model.TaskStatusInProgress
case "succeed":
task.Status = model.TaskStatusSuccess
task.Progress = "100%"
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
task.FailReason = url
} else {
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
}
case "failed":
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if reason, ok := data["fail_reason"].(string); ok {
task.FailReason = reason
}
}
}
// If task failed, refund quota
if task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
task.Data = responseBody
if err := task.Update(); err != nil {
common.SysError("UpdateVideoTask task error: " + err.Error())
}
return nil
}

View File

@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
payType := "wxpay"
if req.PaymentMethod == "zfb" {
payType = "alipay"
}
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = "wxpay"
if !setting.ContainsPayMethod(req.PaymentMethod) {
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
return
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType,
Type: req.PaymentMethod,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),

View File

@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
if setting.DefaultUseAutoGroup {
token.Group = "auto"
}
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -15,6 +15,7 @@ type ImageRequest struct {
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
}
type ImageResponse struct {

View File

@@ -53,6 +53,7 @@ type GeneralOpenAIRequest struct {
Modalities json.RawMessage `json:"modalities,omitempty"`
Audio json.RawMessage `json:"audio,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` // ali
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
// OpenRouter Params

49
dto/ratio_sync.go Normal file
View File

@@ -0,0 +1,49 @@
package dto
// UpstreamDTO 提交到后端同步倍率的上游渠道信息
// Endpoint 可以为空,后端会默认使用 /api/ratio_config
// BaseURL 必须以 http/https 开头,不要以 / 结尾
// 例如: https://api.example.com
// Endpoint: /api/ratio_config
// 提交示例:
// {
// "name": "openai",
// "base_url": "https://api.openai.com",
// "endpoint": "/ratio_config"
// }
type UpstreamDTO struct {
Name string `json:"name" binding:"required"`
BaseURL string `json:"base_url" binding:"required"`
Endpoint string `json:"endpoint"`
}
type UpstreamRequest struct {
ChannelIDs []int64 `json:"channel_ids"`
Timeout int `json:"timeout"`
}
// TestResult 上游测试连通性结果
type TestResult struct {
Name string `json:"name"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
}
// DifferenceItem 差异项
// Current 为本地值,可能为 nil
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
type DifferenceItem struct {
Current interface{} `json:"current"`
Upstreams map[string]interface{} `json:"upstreams"`
}
// SyncableChannel 可同步的渠道信息base_url 不为空)
type SyncableChannel struct {
ID int `json:"id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
Status int `json:"status"`
}

47
dto/video.go Normal file
View File

@@ -0,0 +1,47 @@
package dto
type VideoRequest struct {
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
Width int `json:"width" example:"512"` // Video width
Height int `json:"height" example:"512"` // Video height
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
User string `json:"user,omitempty" example:"user-1234"` // User identifier
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
}
// VideoResponse 视频生成提交任务后的响应
type VideoResponse struct {
TaskId string `json:"task_id"`
Status string `json:"status"`
}
// VideoTaskResponse 查询视频生成任务状态的响应
type VideoTaskResponse struct {
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
Status string `json:"status" example:"succeeded"` // 任务状态
Url string `json:"url,omitempty"` // 视频资源URL成功时
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
}
// VideoTaskMetadata 视频任务元数据
type VideoTaskMetadata struct {
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
Fps int `json:"fps" example:"30"` // 实际帧率
Width int `json:"width" example:"512"` // 实际宽度
Height int `json:"height" example:"512"` // 实际高度
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
}
// VideoTaskError 视频任务错误信息
type VideoTaskError struct {
Code int `json:"code"`
Message string `json:"message"`
}

View File

@@ -12,7 +12,7 @@ import (
"one-api/model"
"one-api/router"
"one-api/service"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"os"
"strconv"
@@ -74,7 +74,7 @@ func main() {
}
// Initialize model settings
operation_setting.InitRatioSettings()
ratio_setting.InitRatioSettings()
// Initialize constants
constant.InitEnv()
// Initialize options
@@ -105,10 +105,12 @@ func main() {
model.InitChannelCache()
}()
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
// 热更新配置
go model.SyncOptions(common.SyncFrequency)
// 数据看板
go model.UpdateQuotaData()

View File

@@ -11,6 +11,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -48,9 +49,11 @@ func Distribute() func(c *gin.Context) {
return
}
// check group in common.GroupRatio
if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
var selectGroup string
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
showGroup := userGroup
if userGroup == "auto" {
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
}
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -162,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeKlingFetchByID {
shouldSelectChannel = false
} else {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
c.Set("platform", string(constant.TaskPlatformKling))
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini

View File

@@ -5,10 +5,13 @@ import (
"fmt"
"math/rand"
"one-api/common"
"one-api/setting"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
var channel *Channel
var err error
selectGroup := group
if group == "auto" {
if len(setting.AutoGroups) == 0 {
return nil, selectGroup, errors.New("auto groups is not enabled")
}
for _, autoGroup := range setting.AutoGroups {
if common.DebugEnabled {
println("autoGroup:", autoGroup)
}
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
if channel == nil {
continue
} else {
c.Set("auto_group", autoGroup)
selectGroup = autoGroup
if common.DebugEnabled {
println("selectGroup:", selectGroup)
}
break
}
}
} else {
channel, err = getRandomSatisfiedChannel(group, model, retry)
if err != nil {
return nil, group, err
}
}
if channel == nil {
return nil, group, errors.New("channel not found")
}
return channel, selectGroup, nil
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}

View File

@@ -597,3 +597,39 @@ func CountAllTags() (int64, error) {
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
return total, err
}
// Get channels of specified type with pagination
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
var channels []*Channel
order := "priority desc"
if idSort {
order = "id desc"
}
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
return channels, err
}
// Count channels of specific type
func CountChannelsByType(channelType int) (int64, error) {
var count int64
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
return count, err
}
// Return map[type]count for all channels
func CountChannelsGroupByType() (map[int64]int64, error) {
type result struct {
Type int64 `gorm:"column:type"`
Count int64 `gorm:"column:count"`
}
var results []result
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
if err != nil {
return nil, err
}
counts := make(map[int64]int64)
for _, r := range results {
counts[r.Type] = r.Count
}
return counts, nil
}

View File

@@ -46,6 +46,15 @@ func initCol() {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
} else {
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
if common.UsingPostgreSQL {
logGroupCol = `"group"`
logKeyCol = `"key"`
} else {
logGroupCol = commonGroupCol
logKeyCol = commonKeyCol
}
}
// log sql type and database type
common.SysLog("Using Log SQL Type: " + common.LogSqlType)

View File

@@ -5,6 +5,7 @@ import (
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -76,6 +77,9 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -94,13 +98,13 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -123,6 +127,7 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled":
ratio_setting.SetExposeRatioEnabled(boolValue)
}
}
switch key {
@@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
case "AutoGroups":
err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = operation_setting.UpdateModelRatioByJSONString(value)
err = ratio_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = setting.UpdateGroupRatioByJSONString(value)
err = ratio_setting.UpdateGroupRatioByJSONString(value)
case "GroupGroupRatio":
err = setting.UpdateGroupGroupRatioByJSONString(value)
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = operation_setting.UpdateCompletionRatioByJSONString(value)
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
err = operation_setting.UpdateModelPriceByJSONString(value)
err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
err = operation_setting.UpdateCacheRatioByJSONString(value)
err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
case "PayMethods":
err = setting.UpdatePayMethodsByJsonString(value)
}
return err
}

View File

@@ -2,7 +2,7 @@ package model
import (
"one-api/common"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"sync"
"time"
)
@@ -65,14 +65,14 @@ func updatePricing() {
ModelName: model,
EnableGroup: groups,
}
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
modelRatio, _ := operation_setting.GetModelRatio(model)
modelRatio, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)

View File

@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
if err != nil {
return err
}

View File

@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
)
}

View File

@@ -2,11 +2,12 @@ package model
import (
"errors"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
const (
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
// check if there's any data to update
hasData := false
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
if len(batchUpdateStores[i]) > 0 {
hasData = true
batchUpdateLocks[i].Unlock()
break
}
batchUpdateLocks[i].Unlock()
}
if !hasData {
return
}
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()

View File

@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}()
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)

View File

@@ -44,4 +44,6 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseResultUrl(resp map[string]any) (string, error)
}

View File

@@ -454,6 +454,7 @@ type ClaudeResponseInfo struct {
Model string
ResponseText strings.Builder
Usage *dto.Usage
Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
// message_start, 获取usage
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != "" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
}
} else if claudeResponse.Type == "message_delta" {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
// 最终的usage获取
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
@@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
@@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysError("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 说明流模式建立失败,可能为官方出错
if claudeInfo.Usage.PromptTokens == 0 {
//usage.PromptTokens = info.PromptTokens
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
//
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
@@ -619,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
if requestMode == RequestModeCompletion {
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
}
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens

View File

@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage

View File

@@ -3,7 +3,6 @@ package cohere
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
}
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
@@ -163,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
})
if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return nil, usage
}

View File

@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
var currentEvent string
var currentData string
var usage dto.Usage
var usage = &dto.Usage{}
for scanner.Scan() {
line := scanner.Text()
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
if line == "" {
if currentEvent != "" && currentData != "" {
// handle last event
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
currentEvent = ""
currentData = ""
}
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
// Last event
if currentEvent != "" && currentData != "" {
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
}
if err := scanner.Err(); err != nil {
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
helper.Done(c)
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
return nil, &usage
return nil, usage
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

View File

@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
return true
})
helper.Done(c)
err := resp.Body.Close()
if err != nil {
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
return nil, usage

View File

@@ -72,10 +72,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking
if strings.HasSuffix(info.OriginModelName, "-thinking") {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}

View File

@@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct {
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
type GeminiChatCandidate struct {

View File

@@ -9,6 +9,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
@@ -35,23 +36,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
// 检查是否有候选响应
if len(geminiResponse.Candidates) == 0 {
return nil, &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}
}
// 计算使用量(基于 UsageMetadata
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
@@ -88,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
helper.SetEventStreamHeaders(c)
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse)
@@ -102,13 +92,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
if part.Text != "" {
responseText.WriteString(part.Text)
}
}
}
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
@@ -121,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
}
// 直接发送 GeminiChatResponse 响应
err = helper.ObjectData(c, geminiResponse)
err = helper.StringData(c, data)
if err != nil {
common.LogError(c, err.Error())
}
@@ -135,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
}
}
// 计算最终使用量
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
// 如果usage.CompletionTokens为0则使用本地统计的completion tokens
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
// 移除流式响应结尾的[Done]因为Gemini API没有发送Done的行为
//helper.Done(c)

View File

@@ -12,6 +12,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strconv"
"strings"
"unicode/utf8"
@@ -36,6 +37,47 @@ var geminiSupportedMimeTypes = map[string]bool{
"video/flv": true,
}
// Gemini 允许的思考预算范围
const (
pro25MinBudget = 128
pro25MaxBudget = 32768
flash25MaxBudget = 24576
flash25LiteMinBudget = 512
flash25LiteMaxBudget = 24576
)
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
func clampThinkingBudget(modelName string, budget int) int {
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
if is25FlashLite {
if budget < flash25LiteMinBudget {
return flash25LiteMinBudget
}
if budget > flash25LiteMaxBudget {
return flash25LiteMaxBudget
}
} else if isNew25Pro {
if budget < pro25MinBudget {
return pro25MinBudget
}
if budget > pro25MaxBudget {
return pro25MaxBudget
}
} else { // 其他模型
if budget < 0 {
return 0
}
if budget > flash25MaxBudget {
return flash25MaxBudget
}
}
return budget
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
@@ -57,16 +99,30 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
if strings.HasSuffix(info.OriginModelName, "-thinking") {
// 硬编码不支持 ThinkingBudget 的旧模型
modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
if strings.Contains(modelName, "-thinking-") {
parts := strings.SplitN(modelName, "-thinking-", 2)
if len(parts) == 2 && parts[1] != "" {
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true,
}
}
}
} else if strings.HasSuffix(modelName, "-thinking") {
unsupportedModels := []string{
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-pro-preview-03-25",
}
isUnsupported := false
for _, unsupportedModel := range unsupportedModels {
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
if strings.HasPrefix(modelName, unsupportedModel) {
isUnsupported = true
break
}
@@ -77,40 +133,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
IncludeThoughts: true,
}
} else {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
// 检查是否为新的2.5pro模型支持ThinkingBudget但有特殊范围
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
if isNew25Pro {
// 新的2.5pro模型ThinkingBudget范围为128-32768
if budgetTokens == 0 || budgetTokens < 128 {
budgetTokens = 128
} else if budgetTokens > 32768 {
budgetTokens = 32768
}
} else {
// 其他模型ThinkingBudget范围为0-24576
if budgetTokens == 0 || budgetTokens > 24576 {
budgetTokens = 24576
}
}
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(int(budgetTokens)),
IncludeThoughts: true,
}
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
}
}
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
// 检查是否为新的2.5pro模型(不支持-nothinking因为最低值只能为128
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
} else if strings.HasSuffix(modelName, "-nothinking") {
if !isNew25Pro {
// 只有非新2.5pro模型才支持-nothinking
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
@@ -283,7 +316,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
url := part.GetImageMedia().Url
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
parts = append(parts, GeminiPart{
@@ -341,7 +375,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if content.Role == "assistant" {
content.Role = "model"
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
if len(content.Parts) > 0 {
geminiRequest.Contents = append(geminiRequest.Contents, content)
}
}
if len(system_content) > 0 {
@@ -611,9 +647,9 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
}
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
@@ -754,7 +790,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
var imageCount int
@@ -849,7 +885,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,

View File

@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
// 特殊处理 responses API
if info.RelayMode == constant.RelayModeResponses {
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
}
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {

View File

@@ -15,6 +15,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"os"
"path/filepath"
"strings"
"github.com/bytedance/gopkg/util/gopool"
@@ -180,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -215,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
StatusCode: resp.StatusCode,
}, nil
}
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
@@ -224,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
@@ -275,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer resp.Body.Close()
usage := &dto.Usage{}
@@ -345,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
defer reqFp.Close()
tmpFp, err := os.CreateTemp("", "audio-*")
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
@@ -365,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
return 0, errors.WithStack(err)
}
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}

View File

@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}

View File

@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@@ -2,7 +2,6 @@ package palm
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
@@ -73,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
@@ -156,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,

View File

@@ -0,0 +1,312 @@
package kling
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Image string `json:"image,omitempty"`
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Model string `json:"model,omitempty"`
ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
accessKey string
secretKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
// apiKey format: "access_key,secret_key"
keyParts := strings.Split(info.ApiKey, ",")
if len(keyParts) == 2 {
a.accessKey = strings.TrimSpace(keyParts[0])
a.secretKey = strings.TrimSpace(keyParts[1])
}
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := "generate"
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("kling_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
token, err := a.createJWTToken()
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return nil
}
// BuildRequestBody converts request into Kling specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("kling_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
body := a.convertToRequestPayload(&req)
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
// Attempt Kling response parse first.
var kResp responsePayload
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
return kResp.Data.TaskID, responseBody, nil
}
// Fallback generic task response.
var generic dto.TaskResponse[string]
if err := json.Unmarshal(responseBody, &generic); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if !generic.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
return generic.Data, responseBody, nil
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
token, err := a.createJWTTokenWithKey(key)
if err != nil {
token = key
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req = req.WithContext(ctx)
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "kling"
}
// ============================
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
r := &requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
Model: req.Model,
ModelName: req.Model,
CfgScale: 0.5,
}
if r.Model == "" {
r.Model = "kling-v1"
r.ModelName = "kling-v1"
}
return r
}
func (a *TaskAdaptor) getAspectRatio(size string) string {
switch size {
case "1024x1024", "512x512":
return "1:1"
case "1280x720", "1920x1080":
return "16:9"
case "720x1280", "1080x1920":
return "9:16"
default:
return "1:1"
}
}
func defaultString(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
func defaultInt(v int, def int) int {
if v == 0 {
return def
}
return v
}
// ============================
// JWT helpers
// ============================
func (a *TaskAdaptor) createJWTToken() (string, error) {
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
parts := strings.Split(apiKey, ",")
if len(parts) != 2 {
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
}
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
}
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
if accessKey == "" || secretKey == "" {
return "", fmt.Errorf("access key and secret key are required")
}
now := time.Now().Unix()
claims := jwt.MapClaims{
"iss": accessKey,
"exp": now + 1800, // 30 minutes
"nbf": now - 5,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["typ"] = "JWT"
return token.SignedString([]byte(secretKey))
}
// ParseResultUrl 提取视频任务结果的 url
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
data, ok := resp["data"].(map[string]any)
if !ok {
return "", fmt.Errorf("data field not found or invalid")
}
taskResult, ok := data["task_result"].(map[string]any)
if !ok {
return "", fmt.Errorf("task_result field not found or invalid")
}
videos, ok := taskResult["videos"].([]interface{})
if !ok || len(videos) == 0 {
return "", fmt.Errorf("videos field not found or empty")
}
video, ok := videos[0].(map[string]interface{})
if !ok {
return "", fmt.Errorf("video item invalid")
}
url, ok := video["url"].(string)
if !ok || url == "" {
return "", fmt.Errorf("url field not found or invalid")
}
return url, nil
}

View File

@@ -22,6 +22,10 @@ type TaskAdaptor struct {
ChannelType int
}
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
return "", nil // todo implement this method if needed
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}

View File

@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = tencentStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage = tencentHandler(c, resp)
}

View File

@@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking
if strings.HasSuffix(info.OriginModelName, "-thinking") {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
@@ -123,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
model,
suffix,
), nil
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
adc.ProjectID,
model,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
model,
suffix,
), nil
}
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",

View File

@@ -1,15 +1,19 @@
package volcengine
import (
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
switch info.RelayMode {
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
}
return
}

View File

@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
})
if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}

View File

@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo.IsStream = true
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {

View File

@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
RelayFormatOpenAIResponses = "openai_responses"
RelayFormatOpenAIAudio = "openai_audio"
RelayFormatOpenAIImage = "openai_image"
RelayFormatRerank = "rerank"
RelayFormatEmbedding = "embedding"
)
type RerankerInfo struct {
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank
info.RelayFormat = RelayFormatRerank
info.RerankerInfo = &RerankerInfo{
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info
}
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIAudio
return info
}
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatEmbedding
return info
}
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses
info.RelayFormat = RelayFormatOpenAIResponses
info.SupportStreamOptions = false
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
return info
}
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatGemini
info.ShouldIncludeUsage = false
return info
}
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIImage
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true
}
// responses 模式不支持 StreamOptions
if relayconstant.RelayModeResponses == info.RelayMode {
info.SupportStreamOptions = false
}
return info
}

View File

@@ -38,6 +38,9 @@ const (
RelayModeSunoFetchByID
RelayModeSunoSubmit
RelayModeKlingFetchByID
RelayModeKlingSubmit
RelayModeRerank
RelayModeResponses
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
}
return relayMode
}
func Path2RelayKling(method, path string) int {
relayMode := RelayModeUnknown
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
relayMode = RelayModeKlingSubmit
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
relayMode = RelayModeKlingFetchByID
}
return relayMode
}

View File

@@ -15,7 +15,7 @@ import (
)
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
return token
}
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
}
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken

View File

@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
return sensitiveWords, err
}
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
}
inputText := strings.Join(inputTexts, "\n")
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
info.PromptTokens = inputTokens
return inputTokens, err
return inputTokens
}
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoGemini(c)
// 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo)
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens, err := getGeminiInputTokens(req, relayInfo)
promptTokens := getGeminiInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
@@ -155,14 +155,33 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
if common.DebugEnabled {
println("Gemini request body: %s", string(requestBody))
}
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}
}
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
if openaiErr != nil {
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
}

View File

@@ -4,12 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
common2 "one-api/common"
"one-api/dto"
"one-api/relay/common"
"github.com/gin-gonic/gin"
)
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
info.UpstreamModelName = currentModel
}
}
if request != nil {
switch info.RelayFormat {
case common.RelayFormatGemini:
// Gemini 模型映射
case common.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
case common.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
case common.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}
default:
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
openAIRequest.Model = info.UpstreamModelName
} else {
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
}
}
}
return nil
}

View File

@@ -5,12 +5,16 @@ import (
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
type GroupRatioInfo struct {
GroupRatio float64
GroupSpecialRatio float64
}
type PriceData struct {
ModelPrice float64
ModelRatio float64
@@ -18,23 +22,50 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
GroupRatio float64
UserGroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
}
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
groupRatioInfo := GroupRatioInfo{
GroupRatio: 1.0, // default ratio
GroupSpecialRatio: -1,
}
// check auto group
autoGroup, exists := ctx.Get("auto_group")
if exists {
if common.DebugEnabled {
println(fmt.Sprintf("final group: %s", autoGroup))
}
relayInfo.Group = autoGroup.(string)
}
// check user group special ratio
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
// user group special ratio
groupRatioInfo.GroupSpecialRatio = userGroupRatio
groupRatioInfo.GroupRatio = userGroupRatio
} else {
// normal group ratio
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.Group)
}
return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
if ok {
groupRatio = userGroupRatio
}
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
groupRatioInfo := HandleGroupRatio(c, info)
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -47,7 +78,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
preConsumedTokens = promptTokens + maxTokens
}
var success bool
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
if !success {
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
@@ -60,22 +91,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请联系管理员设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
}
}
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatio: groupRatio,
UserGroupRatio: userGroupRatio,
GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
@@ -91,11 +121,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := operation_setting.GetModelPrice(modelName, false)
_, ok := ratio_setting.GetModelPrice(modelName, false)
if ok {
return true
}
_, ok = operation_setting.GetModelRatio(modelName)
_, ok = ratio_setting.GetModelRatio(modelName)
if ok {
return true
}

View File

@@ -17,6 +17,8 @@ import (
"one-api/setting"
"strings"
"one-api/relay/constant"
"github.com/gin-gonic/gin"
)
@@ -44,6 +46,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if imageRequest.N == 0 {
imageRequest.N = 1
}
if info.ApiType == constant.APITypeVolcEngine {
watermark := formData.Has("watermark")
imageRequest.Watermark = &watermark
}
default:
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
@@ -102,7 +109,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
}
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c)
relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
@@ -110,13 +117,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
imageRequest.Model = relayInfo.UpstreamModelName
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
@@ -162,7 +167,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)

View File

@@ -15,7 +15,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -174,17 +174,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := setting.GetGroupRatio(group)
groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
@@ -480,17 +480,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
groupRatio := setting.GetGroupRatio(group)
groupRatio := ratio_setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {

View File

@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
}
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
if textRequest.WebSearchOptions != nil {
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
}
if setting.ShouldCheckPromptSensitive() {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
@@ -107,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
@@ -252,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
default:
err = errors.New("unknown relay mode")
promptTokens = 0
@@ -361,9 +360,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
userGroupRatio := priceData.UserGroupRatio
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -511,7 +509,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio

View File

@@ -22,6 +22,7 @@ import (
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
// return &aiproxy.Adaptor{}
case commonconstant.TaskPlatformSuno:
return &suno.TaskAdaptor{}
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
}
return nil
}

View File

@@ -15,8 +15,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
)
/*
@@ -38,9 +37,12 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
if platform == constant.TaskPlatformKling {
modelName = relayInfo.OriginModelName
}
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -49,7 +51,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
groupRatio := setting.GetGroupRatio(relayInfo.Group)
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
@@ -137,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
relayInfo.ConsumeQuota = true
// insert task
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
task := model.InitTask(platform, relayInfo)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
task.Action = relayInfo.Action
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
@@ -150,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
@@ -226,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
return
}
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
TaskID: task.TaskID,

View File

@@ -14,12 +14,10 @@ import (
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
for _, document := range rerankRequest.Documents {
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
if err == nil {
token += tkm
}
tkm := service.CountTokenInput(document, rerankRequest.Model)
token += tkm
}
return token
}
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken

View File

@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
return sensitiveWords, err
}
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
inputTokens := service.CountTokenInput(req.Input, req.Model)
info.PromptTokens = inputTokens
return inputTokens, err
return inputTokens
}
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
}
}
err = helper.ModelMappedHelper(c, relayInfo)
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
}
req.Model = relayInfo.UpstreamModelName
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens)
} else {
promptTokens, err := getInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
promptTokens := getInputTokens(req, relayInfo)
c.Set("prompt_tokens", promptTokens)
}

View File

@@ -6,12 +6,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//isModelMapped = true
}
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
//if constant.ShouldCheckPromptSensitive() {
// err = checkRequestSensitive(textRequest, relayInfo)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
// }
//}
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
//// count messages token error 计算promptTokens错误
//if err != nil {
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
//}
//
if !getModelPriceSuccess {
preConsumedTokens := common.PreConsumedQuota
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
//}
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
relayInfo.UsePrice = true
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
userQuota, priceData, "")
return nil
}

View File

@@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
userRoute := apiRouter.Group("/user")
{
@@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
}
ratioSyncRoute := apiRouter.Group("/ratio_sync")
ratioSyncRoute.Use(middleware.RootAuth())
{
ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
{

View File

@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)
SetVideoRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""

17
router/video-router.go Normal file
View File

@@ -0,0 +1,17 @@
package router
import (
"one-api/controller"
"one-api/middleware"
"github.com/gin-gonic/gin"
)
func SetVideoRouter(router *gin.Engine) {
videoV1Router := router.Group("/v1")
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
}
}

View File

@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
case "billing_not_active":
return true
case "pre_consume_token_quota_failed":
return true
}
switch err.Error.Type {
case "insufficient_quota":

View File

@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
}
openAIError := dto.OpenAIError{
Message: text,
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
if !strings.HasPrefix(lowerText, "get file base64 from url") {
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
}
claudeError := dto.ClaudeError{
Message: text,

View File

@@ -4,8 +4,10 @@ import (
"encoding/base64"
"fmt"
"io"
"one-api/common"
"one-api/constant"
"one-api/dto"
"strings"
)
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
mimeType := resp.Header.Get("Content-Type")
if len(strings.Split(mimeType, ";")) > 1 {
// If Content-Type has parameters, take the first part
mimeType = strings.Split(mimeType, ";")[0]
}
if mimeType == "application/octet-stream" {
if common.DebugEnabled {
println("MIME type is application/octet-stream, trying to guess from URL or filename")
}
// try to guess the MIME type from the url last segment
urlParts := strings.Split(url, "/")
if len(urlParts) > 0 {
lastSegment := urlParts[len(urlParts)-1]
if strings.Contains(lastSegment, ".") {
// Extract the file extension
filename := strings.Split(lastSegment, ".")
if len(filename) > 1 {
ext := strings.ToLower(filename[len(filename)-1])
// Guess MIME type based on file extension
mimeType = GetMimeTypeByExtension(ext)
}
}
} else {
// try to guess the MIME type from the file extension
fileName := resp.Header.Get("Content-Disposition")
if fileName != "" {
// Extract the filename from the Content-Disposition header
parts := strings.Split(fileName, ";")
for _, part := range parts {
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
// Remove quotes if present
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
fileName = fileName[1 : len(fileName)-1]
}
// Guess MIME type based on file extension
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
mimeType = GetMimeTypeByExtension(ext)
}
break
}
}
}
}
}
return &dto.LocalFileData{
Base64Data: base64Data,
MimeType: resp.Header.Get("Content-Type"),
MimeType: mimeType,
Size: int64(len(fileBytes)),
}, nil
}
func GetMimeTypeByExtension(ext string) string {
// Convert to lowercase for case-insensitive comparison
ext = strings.ToLower(ext)
switch ext {
// Text files
case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
return "text/plain"
// Image files
case "jpg", "jpeg":
return "image/jpeg"
case "png":
return "image/png"
case "gif":
return "image/gif"
// Audio files
case "mp3":
return "audio/mp3"
case "wav":
return "audio/wav"
case "mpeg":
return "audio/mpeg"
// Video files
case "mp4":
return "video/mp4"
case "wmv":
return "video/wmv"
case "flv":
return "video/flv"
case "mov":
return "video/mov"
case "mpg":
return "video/mpg"
case "avi":
return "video/avi"
case "mpegps":
return "video/mpegps"
// Document files
case "pdf":
return "application/pdf"
default:
return "application/octet-stream" // Default for unknown types
}
}

View File

@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
"log"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -10,7 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/ratio_setting"
"strings"
"time"
@@ -45,9 +46,9 @@ func calculateAudioQuota(info QuotaInfo) int {
return int(quota.IntPart())
}
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
groupRatio := decimal.NewFromFloat(info.GroupRatio)
modelRatio := decimal.NewFromFloat(info.ModelRatio)
@@ -93,12 +94,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
groupRatio = userGroupRatio
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
modelRatio, _ := ratio_setting.GetModelRatio(modelName)
autoGroup, exists := ctx.Get("auto_group")
if exists {
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
log.Printf("final group ratio: %f", groupRatio)
relayInfo.Group = autoGroup.(string)
}
actualGroupRatio := groupRatio
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
actualGroupRatio = userGroupRatio
}
modelRatio, _ := operation_setting.GetModelRatio(modelName)
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -112,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
GroupRatio: groupRatio,
GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -134,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -145,15 +154,15 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
actualGroupRatio = userGroupRatio
}
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -166,7 +175,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
ModelName: modelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: actualGroupRatio,
GroupRatio: groupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -198,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -214,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
userGroupRatio := priceData.UserGroupRatio
cacheRatio := priceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -265,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio)
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -281,21 +289,15 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
actualGroupRatio = userGroupRatio
}
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -308,7 +310,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
GroupRatio: actualGroupRatio,
GroupRatio: groupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -348,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}

View File

@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, request.Model)
toolTokens := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message
if request.System != "" {
systemTokens, err := CountTokenInput(request.System, model)
systemTokens := CountTokenInput(request.System, model)
if err != nil {
return 0, err
}
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
msgTokens := CountTextToken(request.Session.Instructions, model)
textToken += msgTokens
}
case dto.RealtimeEventResponseAudioDelta:
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
tkm, err := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
tkm := CountTextToken(request.Delta, model)
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
case "message":
for _, content := range request.Item.Content {
if content.Type == "input_text" {
tokens, err := CountTextToken(content.Text, model)
if err != nil {
return 0, 0, err
}
tokens := CountTextToken(content.Text, model)
textToken += tokens
}
}
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
toolTokens := CountTokenInput(tool, model)
textToken += 8
textToken += toolTokens
}
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
return tokenNum, nil
}
func CountTokenInput(input any, model string) (int, error) {
func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return CountTextToken(v, model)
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
tkm := CountTokenInput(message.Delta.GetContentString(), model)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
tkm, _ := CountTokenInput(tool.Function.Name, model)
tkm := CountTokenInput(tool.Function.Name, model)
tokens += tkm
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
tkm = CountTokenInput(tool.Function.Arguments, model)
tokens += tkm
}
}
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
func CountTTSToken(text string, model string) (int, error) {
func CountTTSToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil
return utf8.RuneCountInString(text)
} else {
return CountTextToken(text, model)
}
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
//}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) {
var err error
func CountTextToken(text string, model string) int {
if text == "" {
return 0
}
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err
return getTokenNum(tokenEncoder, text)
}

View File

@@ -16,13 +16,13 @@ import (
// return 0, errors.New("unknown relay mode")
//}
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
ctkm, err := CountTextToken(responseText, modeName)
ctkm := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err
return usage
}
func ValidUsage(usage *dto.Usage) bool {

31
setting/auto_group.go Normal file
View File

@@ -0,0 +1,31 @@
package setting
import "encoding/json"
var AutoGroups = []string{
"default",
}
var DefaultUseAutoGroup = false
func ContainsAutoGroup(group string) bool {
for _, autoGroup := range AutoGroups {
if autoGroup == group {
return true
}
}
return false
}
func UpdateAutoGroupsByJsonString(jsonString string) error {
AutoGroups = make([]string, 0)
return json.Unmarshal([]byte(jsonString), &AutoGroups)
}
func AutoGroups2JsonString() string {
jsonBytes, err := json.Marshal(AutoGroups)
if err != nil {
return "[]"
}
return string(jsonBytes)
}

View File

@@ -7,6 +7,7 @@ import (
"regexp"
"strings"
"time"
"sort"
)
var (
@@ -210,8 +211,23 @@ func validateFAQ(faqStr string) error {
return nil
}
func getPublishTime(item map[string]interface{}) time.Time {
if v, ok := item["publishDate"]; ok {
if s, ok2 := v.(string); ok2 {
if t, err := time.Parse(time.RFC3339, s); err == nil {
return t
}
}
}
return time.Time{}
}
func GetAnnouncements() []map[string]interface{} {
return getJSONList(GetConsoleSetting().Announcements)
list := getJSONList(GetConsoleSetting().Announcements)
sort.SliceStable(list, func(i, j int) bool {
return getPublishTime(list[i]).After(getPublishTime(list[j]))
})
return list
}
func GetFAQ() []map[string]interface{} {

View File

@@ -17,6 +17,8 @@ const (
const (
// Gemini Audio Input Price
Gemini25FlashPreviewInputAudioPrice = 1.00
Gemini25FlashProductionInputAudioPrice = 1.00 // for `gemini-2.5-flash`
Gemini25FlashLitePreviewInputAudioPrice = 0.50
Gemini25FlashNativeAudioInputAudioPrice = 3.00
Gemini20FlashInputAudioPrice = 0.70
)
@@ -64,10 +66,14 @@ func GetFileSearchPricePerThousand() float64 {
}
func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") {
return Gemini25FlashPreviewInputAudioPrice
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") {
if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") {
return Gemini25FlashNativeAudioInputAudioPrice
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-lite") {
return Gemini25FlashLitePreviewInputAudioPrice
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") {
return Gemini25FlashPreviewInputAudioPrice
} else if strings.HasPrefix(modelName, "gemini-2.5-flash") {
return Gemini25FlashProductionInputAudioPrice
} else if strings.HasPrefix(modelName, "gemini-2.0-flash") {
return Gemini20FlashInputAudioPrice
}

View File

@@ -1,8 +1,45 @@
package setting
import "encoding/json"
var PayAddress = ""
var CustomCallbackAddress = ""
var EpayId = ""
var EpayKey = ""
var Price = 7.3
var MinTopUp = 1
var PayMethods = []map[string]string{
{
"name": "支付宝",
"color": "rgba(var(--semi-blue-5), 1)",
"type": "alipay",
},
{
"name": "微信",
"color": "rgba(var(--semi-green-5), 1)",
"type": "wxpay",
},
}
func UpdatePayMethodsByJsonString(jsonString string) error {
PayMethods = make([]map[string]string, 0)
return json.Unmarshal([]byte(jsonString), &PayMethods)
}
func PayMethods2JsonString() string {
jsonBytes, err := json.Marshal(PayMethods)
if err != nil {
return "[]"
}
return string(jsonBytes)
}
func ContainsPayMethod(method string) bool {
for _, payMethod := range PayMethods {
if payMethod["type"] == method {
return true
}
}
return false
}

View File

@@ -1,4 +1,4 @@
package operation_setting
package ratio_setting
import (
"encoding/json"
@@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock()
cacheRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
// GetCacheRatio returns the cache ratio for a model
@@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) {
}
return ratio, true
}
func GetCacheRatioCopy() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(cacheRatioMap))
for k, v := range cacheRatioMap {
copyMap[k] = v
}
return copyMap
}

View File

@@ -0,0 +1,17 @@
package ratio_setting
import "sync/atomic"
var exposeRatioEnabled atomic.Bool
func init() {
exposeRatioEnabled.Store(false)
}
func SetExposeRatioEnabled(enabled bool) {
exposeRatioEnabled.Store(enabled)
}
func IsExposeRatioEnabled() bool {
return exposeRatioEnabled.Load()
}

View File

@@ -0,0 +1,55 @@
package ratio_setting
import (
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
)
const exposedDataTTL = 30 * time.Second
type exposedCache struct {
data gin.H
expiresAt time.Time
}
var (
exposedData atomic.Value
rebuildMu sync.Mutex
)
func InvalidateExposedDataCache() {
exposedData.Store((*exposedCache)(nil))
}
func cloneGinH(src gin.H) gin.H {
dst := make(gin.H, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
func GetExposedData() gin.H {
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
return cloneGinH(c.data)
}
rebuildMu.Lock()
defer rebuildMu.Unlock()
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
return cloneGinH(c.data)
}
newData := gin.H{
"model_ratio": GetModelRatioCopy(),
"completion_ratio": GetCompletionRatioCopy(),
"cache_ratio": GetCacheRatioCopy(),
"model_price": GetModelPriceCopy(),
}
exposedData.Store(&exposedCache{
data: newData,
expiresAt: time.Now().Add(exposedDataTTL),
})
return cloneGinH(newData)
}

View File

@@ -1,4 +1,4 @@
package setting
package ratio_setting
import (
"encoding/json"

View File

@@ -1,8 +1,9 @@
package operation_setting
package ratio_setting
import (
"encoding/json"
"one-api/common"
"one-api/setting/operation_setting"
"strings"
"sync"
)
@@ -139,9 +140,17 @@ var defaultModelRatio = map[string]float64{
"gemini-2.0-flash": 0.05,
"gemini-2.5-pro-exp-03-25": 0.625,
"gemini-2.5-pro-preview-03-25": 0.625,
"gemini-2.5-pro": 0.625,
"gemini-2.5-flash-preview-04-17": 0.075,
"gemini-2.5-flash-preview-04-17-thinking": 0.075,
"gemini-2.5-flash-preview-04-17-nothinking": 0.075,
"gemini-2.5-flash-preview-05-20": 0.075,
"gemini-2.5-flash-preview-05-20-thinking": 0.075,
"gemini-2.5-flash-preview-05-20-nothinking": 0.075,
"gemini-2.5-flash-thinking-*": 0.075, // 用于为后续所有2.5 flash thinking budget 模型设置默认倍率
"gemini-2.5-pro-thinking-*": 0.625, // 用于为后续所有2.5 pro thinking budget 模型设置默认倍率
"gemini-2.5-flash-lite-preview-06-17": 0.05,
"gemini-2.5-flash": 0.15,
"text-embedding-004": 0.001,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
@@ -311,7 +320,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
@@ -339,19 +352,33 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
err := json.Unmarshal([]byte(jsonStr), &modelRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
// 处理带有思考预算的模型名称,方便统一定价
func handleThinkingBudgetModel(name, prefix, wildcard string) string {
if strings.HasPrefix(name, prefix) && strings.Contains(name, "-thinking-") {
return wildcard
}
return name
}
func GetModelRatio(name string) (float64, bool) {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
ratio, ok := modelRatioMap[name]
if !ok {
return 37.5, SelfUseModeEnabled
return 37.5, operation_setting.SelfUseModeEnabled
}
return ratio, true
}
@@ -389,13 +416,22 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
err := json.Unmarshal([]byte(jsonStr), &CompletionRatio)
if err == nil {
InvalidateExposedDataCache()
}
return err
}
func GetCompletionRatio(name string) float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -413,12 +449,6 @@ func GetCompletionRatio(name string) float64 {
func getHardcodedCompletionModelRatio(name string) (float64, bool) {
lowercaseName := strings.ToLower(name)
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
@@ -470,14 +500,19 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.0") {
return 4, true
} else if strings.HasPrefix(name, "gemini-2.5-pro-preview") {
} else if strings.HasPrefix(name, "gemini-2.5-pro") { // 移除preview来增加兼容性这里假设正式版的倍率和preview一致
return 8, true
} else if strings.HasPrefix(name, "gemini-2.5-flash-preview") {
if strings.HasSuffix(name, "-nothinking") {
return 4, false
} else {
return 3.5 / 0.6, false
} else if strings.HasPrefix(name, "gemini-2.5-flash") { // 处理不同的flash模型倍率
if strings.HasPrefix(name, "gemini-2.5-flash-preview") {
if strings.HasSuffix(name, "-nothinking") {
return 4, true
}
return 3.5 / 0.15, true
}
if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") {
return 4, true
}
return 2.5 / 0.3, true
}
return 4, false
}
@@ -593,3 +628,33 @@ func GetImageRatio(name string) (float64, bool) {
}
return ratio, true
}
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelRatioMap))
for k, v := range modelRatioMap {
copyMap[k] = v
}
return copyMap
}
func GetModelPriceCopy() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelPriceMap))
for k, v := range modelPriceMap {
copyMap[k] = v
}
return copyMap
}
func GetCompletionRatioCopy() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
copyMap := make(map[string]float64, len(CompletionRatio))
for k, v := range CompletionRatio {
copyMap[k] = v
}
return copyMap
}

View File

@@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool {
_, ok := userUsableGroups[groupName]
return ok
}
func GetUsableGroupDescription(groupName string) string {
if desc, ok := userUsableGroups[groupName]; ok {
return desc
}
return groupName
}

View File

@@ -28,6 +28,7 @@ import {
Tag,
Typography,
Skeleton,
Badge,
} from '@douyinfe/semi-ui';
import { StatusContext } from '../../context/Status/index.js';
import { useStyle, styleActions } from '../../context/Style/index.js';
@@ -43,6 +44,7 @@ const HeaderBar = () => {
const [mobileMenuOpen, setMobileMenuOpen] = useState(false);
const location = useLocation();
const [noticeVisible, setNoticeVisible] = useState(false);
const [unreadCount, setUnreadCount] = useState(0);
const systemName = getSystemName();
const logo = getLogo();
@@ -53,9 +55,44 @@ const HeaderBar = () => {
const docsLink = statusState?.status?.docs_link || '';
const isDemoSiteMode = statusState?.status?.demo_site_enabled || false;
const isConsoleRoute = location.pathname.startsWith('/console');
const theme = useTheme();
const setTheme = useSetTheme();
const announcements = statusState?.status?.announcements || [];
const getAnnouncementKey = (a) => `${a?.publishDate || ''}-${(a?.content || '').slice(0, 30)}`;
const calculateUnreadCount = () => {
if (!announcements.length) return 0;
let readKeys = [];
try {
readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || [];
} catch (_) {
readKeys = [];
}
const readSet = new Set(readKeys);
return announcements.filter((a) => !readSet.has(getAnnouncementKey(a))).length;
};
const getUnreadKeys = () => {
if (!announcements.length) return [];
let readKeys = [];
try {
readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || [];
} catch (_) {
readKeys = [];
}
const readSet = new Set(readKeys);
return announcements.filter((a) => !readSet.has(getAnnouncementKey(a))).map(getAnnouncementKey);
};
useEffect(() => {
setUnreadCount(calculateUnreadCount());
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [announcements]);
const mainNavLinks = [
{
text: t('首页'),
@@ -106,6 +143,25 @@ const HeaderBar = () => {
}, 3000);
};
const handleNoticeOpen = () => {
setNoticeVisible(true);
};
const handleNoticeClose = () => {
setNoticeVisible(false);
if (announcements.length) {
let readKeys = [];
try {
readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || [];
} catch (_) {
readKeys = [];
}
const mergedKeys = Array.from(new Set([...readKeys, ...announcements.map(getAnnouncementKey)]));
localStorage.setItem('notice_read_keys', JSON.stringify(mergedKeys));
}
setUnreadCount(0);
};
useEffect(() => {
if (theme === 'dark') {
document.body.setAttribute('theme-mode', 'dark');
@@ -353,15 +409,14 @@ const HeaderBar = () => {
}
};
// 检查当前路由是否以/console开头
const isConsoleRoute = location.pathname.startsWith('/console');
return (
<header className="text-semi-color-text-0 sticky top-0 z-50 transition-colors duration-300 bg-white/75 dark:bg-zinc-900/75 backdrop-blur-lg">
<NoticeModal
visible={noticeVisible}
onClose={() => setNoticeVisible(false)}
onClose={handleNoticeClose}
isMobile={styleState.isMobile}
defaultTab={unreadCount > 0 ? 'system' : 'inApp'}
unreadKeys={getUnreadKeys()}
/>
<div className="w-full px-2">
<div className="flex items-center justify-between h-16">
@@ -462,14 +517,27 @@ const HeaderBar = () => {
</Dropdown>
)}
<Button
icon={<IconBell className="text-lg" />}
aria-label={t('系统公告')}
onClick={() => setNoticeVisible(true)}
theme="borderless"
type="tertiary"
className="!p-1.5 !text-current focus:!bg-semi-color-fill-1 dark:focus:!bg-gray-700 !rounded-full !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2"
/>
{unreadCount > 0 ? (
<Badge count={unreadCount} type="danger" overflowCount={99}>
<Button
icon={<IconBell className="text-lg" />}
aria-label={t('系统公告')}
onClick={handleNoticeOpen}
theme="borderless"
type="tertiary"
className="!p-1.5 !text-current focus:!bg-semi-color-fill-1 dark:focus:!bg-gray-700 !rounded-full !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2"
/>
</Badge>
) : (
<Button
icon={<IconBell className="text-lg" />}
aria-label={t('系统公告')}
onClick={handleNoticeOpen}
theme="borderless"
type="tertiary"
className="!p-1.5 !text-current focus:!bg-semi-color-fill-1 dark:focus:!bg-gray-700 !rounded-full !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2"
/>
)}
<Button
icon={theme === 'dark' ? <IconSun size="large" className="text-yellow-500" /> : <IconMoon size="large" className="text-gray-300" />}

View File

@@ -1,14 +1,36 @@
import React, { useEffect, useState } from 'react';
import { Button, Modal, Empty } from '@douyinfe/semi-ui';
import React, { useEffect, useState, useContext, useMemo } from 'react';
import { Button, Modal, Empty, Tabs, TabPane, Timeline } from '@douyinfe/semi-ui';
import { useTranslation } from 'react-i18next';
import { API, showError } from '../../helpers';
import { API, showError, getRelativeTime } from '../../helpers';
import { marked } from 'marked';
import { IllustrationNoContent, IllustrationNoContentDark } from '@douyinfe/semi-illustrations';
import { StatusContext } from '../../context/Status/index.js';
import { Bell, Megaphone } from 'lucide-react';
const NoticeModal = ({ visible, onClose, isMobile }) => {
const NoticeModal = ({ visible, onClose, isMobile, defaultTab = 'inApp', unreadKeys = [] }) => {
const { t } = useTranslation();
const [noticeContent, setNoticeContent] = useState('');
const [loading, setLoading] = useState(false);
const [activeTab, setActiveTab] = useState(defaultTab);
const [statusState] = useContext(StatusContext);
const announcements = statusState?.status?.announcements || [];
const unreadSet = useMemo(() => new Set(unreadKeys), [unreadKeys]);
const getKeyForItem = (item) => `${item?.publishDate || ''}-${(item?.content || '').slice(0, 30)}`;
const processedAnnouncements = useMemo(() => {
return (announcements || []).slice(0, 20).map(item => ({
key: getKeyForItem(item),
type: item.type || 'default',
time: getRelativeTime(item.publishDate),
content: item.content,
extra: item.extra,
isUnread: unreadSet.has(getKeyForItem(item))
}));
}, [announcements, unreadSet]);
const handleCloseTodayNotice = () => {
const today = new Date().toDateString();
@@ -44,7 +66,13 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
}
}, [visible]);
const renderContent = () => {
useEffect(() => {
if (visible) {
setActiveTab(defaultTab);
}
}, [defaultTab, visible]);
const renderMarkdownNotice = () => {
if (loading) {
return <div className="py-12"><Empty description={t('加载中...')} /></div>;
}
@@ -64,14 +92,74 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
return (
<div
dangerouslySetInnerHTML={{ __html: noticeContent }}
className="notice-content-scroll max-h-[60vh] overflow-y-auto pr-2"
className="notice-content-scroll max-h-[55vh] overflow-y-auto pr-2"
/>
);
};
const renderAnnouncementTimeline = () => {
if (processedAnnouncements.length === 0) {
return (
<div className="py-12">
<Empty
image={<IllustrationNoContent style={{ width: 150, height: 150 }} />}
darkModeImage={<IllustrationNoContentDark style={{ width: 150, height: 150 }} />}
description={t('暂无系统公告')}
/>
</div>
);
}
return (
<div className="max-h-[55vh] overflow-y-auto pr-2 card-content-scroll">
<Timeline mode="alternate">
{processedAnnouncements.map((item, idx) => (
<Timeline.Item
key={idx}
type={item.type}
time={item.time}
className={item.isUnread ? '' : ''}
>
<div>
{item.isUnread ? (
<span className="shine-text">
{item.content}
</span>
) : (
item.content
)}
{item.extra && <div className="text-xs text-gray-500">{item.extra}</div>}
</div>
</Timeline.Item>
))}
</Timeline>
</div>
);
};
const renderBody = () => {
if (activeTab === 'inApp') {
return renderMarkdownNotice();
}
return renderAnnouncementTimeline();
};
return (
<Modal
title={t('系统公告')}
title={
<div className="flex items-center justify-between w-full">
<span>{t('系统公告')}</span>
<Tabs
activeKey={activeTab}
onChange={setActiveTab}
type='card'
size='small'
>
<TabPane tab={<span className="flex items-center gap-1"><Bell size={14} /> {t('通知')}</span>} itemKey='inApp' />
<TabPane tab={<span className="flex items-center gap-1"><Megaphone size={14} /> {t('系统公告')}</span>} itemKey='system' />
</Tabs>
</div>
}
visible={visible}
onCancel={onClose}
footer={(
@@ -82,7 +170,7 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
)}
size={isMobile ? 'full-width' : 'large'}
>
{renderContent()}
{renderBody()}
</Modal>
);
};

View File

@@ -0,0 +1,143 @@
import React, { useState } from 'react';
import {
Modal,
Transfer,
Input,
Space,
Checkbox,
Avatar,
Highlight,
} from '@douyinfe/semi-ui';
import { IconClose } from '@douyinfe/semi-icons';
const CHANNEL_STATUS_CONFIG = {
1: { color: 'green', text: '启用' },
2: { color: 'red', text: '禁用' },
3: { color: 'amber', text: '自禁' },
default: { color: 'grey', text: '未知' }
};
const getChannelStatusConfig = (status) => {
return CHANNEL_STATUS_CONFIG[status] || CHANNEL_STATUS_CONFIG.default;
};
export default function ChannelSelectorModal({
t,
visible,
onCancel,
onOk,
allChannels = [],
selectedChannelIds = [],
setSelectedChannelIds,
channelEndpoints,
updateChannelEndpoint,
}) {
const [searchText, setSearchText] = useState('');
const ChannelInfo = ({ item, showEndpoint = false, isSelected = false }) => {
const channelId = item.key || item.value;
const currentEndpoint = channelEndpoints[channelId];
const baseUrl = item._originalData?.base_url || '';
const status = item._originalData?.status || 0;
const statusConfig = getChannelStatusConfig(status);
return (
<>
<Avatar color={statusConfig.color} size="small">
{statusConfig.text}
</Avatar>
<div className="info">
<div className="name">
{isSelected ? (
item.label
) : (
<Highlight sourceString={item.label} searchWords={[searchText]} />
)}
</div>
<div className="email" style={showEndpoint ? { display: 'flex', alignItems: 'center', gap: '4px' } : {}}>
<span className="text-xs text-gray-500 truncate max-w-[200px]" title={baseUrl}>
{isSelected ? (
baseUrl
) : (
<Highlight sourceString={baseUrl} searchWords={[searchText]} />
)}
</span>
{showEndpoint && (
<Input
size="small"
value={currentEndpoint}
onChange={(value) => updateChannelEndpoint(channelId, value)}
placeholder="/api/ratio_config"
className="flex-1 text-xs"
style={{ fontSize: '12px' }}
/>
)}
{isSelected && !showEndpoint && (
<span className="text-xs text-gray-700 font-mono bg-gray-100 px-2 py-1 rounded ml-2">
{currentEndpoint}
</span>
)}
</div>
</div>
</>
);
};
const renderSourceItem = (item) => {
return (
<div className="components-transfer-source-item" key={item.key}>
<Checkbox
onChange={item.onChange}
checked={item.checked}
style={{ height: 52, alignItems: 'center' }}
>
<ChannelInfo item={item} showEndpoint={true} />
</Checkbox>
</div>
);
};
const renderSelectedItem = (item) => {
return (
<div className="components-transfer-selected-item" key={item.key}>
<ChannelInfo item={item} isSelected={true} />
<IconClose style={{ cursor: 'pointer' }} onClick={item.onRemove} />
</div>
);
};
const channelFilter = (input, item) => {
const searchLower = input.toLowerCase();
return item.label.toLowerCase().includes(searchLower) ||
(item._originalData?.base_url || '').toLowerCase().includes(searchLower);
};
return (
<Modal
visible={visible}
onCancel={onCancel}
onOk={onOk}
title={<span className="text-lg font-semibold">{t('选择同步渠道')}</span>}
width={1000}
>
<Space vertical style={{ width: '100%' }}>
<Transfer
style={{ width: '100%' }}
dataSource={allChannels}
value={selectedChannelIds}
onChange={setSelectedChannelIds}
renderSourceItem={renderSourceItem}
renderSelectedItem={renderSelectedItem}
filter={channelFilter}
inputProps={{ placeholder: t('搜索渠道名称或地址') }}
onSearch={setSearchText}
emptyContent={{
left: t('暂无渠道'),
right: t('暂无选择'),
search: t('无搜索结果'),
}}
/>
</Space>
</Modal>
);
}

View File

@@ -0,0 +1,63 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin } from '@douyinfe/semi-ui';
import SettingsChats from '../../pages/Setting/Chat/SettingsChats.js';
import { API, showError } from '../../helpers';
const ChatsSetting = () => {
let [inputs, setInputs] = useState({
/* 聊天设置 */
Chats: '[]',
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key.endsWith('Enabled') ||
['DefaultCollapseSidebar'].includes(item.key)
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* 聊天设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsChats options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default ChatsSetting;

View File

@@ -5,6 +5,7 @@ import SettingsAPIInfo from '../../pages/Setting/Dashboard/SettingsAPIInfo.js';
import SettingsAnnouncements from '../../pages/Setting/Dashboard/SettingsAnnouncements.js';
import SettingsFAQ from '../../pages/Setting/Dashboard/SettingsFAQ.js';
import SettingsUptimeKuma from '../../pages/Setting/Dashboard/SettingsUptimeKuma.js';
import SettingsDataDashboard from '../../pages/Setting/Dashboard/SettingsDataDashboard.js';
const DashboardSetting = () => {
let [inputs, setInputs] = useState({
@@ -23,6 +24,11 @@ const DashboardSetting = () => {
FAQ: '',
UptimeKumaUrl: '',
UptimeKumaSlug: '',
/* 数据看板 */
DataExportEnabled: false,
DataExportDefaultTime: 'hour',
DataExportInterval: 5,
});
let [loading, setLoading] = useState(false);
@@ -37,6 +43,10 @@ const DashboardSetting = () => {
if (item.key in inputs) {
newInputs[item.key] = item.value;
}
if (item.key.endsWith('Enabled') &&
(item.key === 'DataExportEnabled')) {
newInputs[item.key] = item.value === 'true' ? true : false;
}
});
setInputs(newInputs);
} else {
@@ -106,9 +116,9 @@ const DashboardSetting = () => {
</p>
</Modal>
{/* API信息管理 */}
{/* 数据看板设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsAPIInfo options={inputs} refresh={onRefresh} />
<SettingsDataDashboard options={inputs} refresh={onRefresh} />
</Card>
{/* 系统公告管理 */}
@@ -116,6 +126,11 @@ const DashboardSetting = () => {
<SettingsAnnouncements options={inputs} refresh={onRefresh} />
</Card>
{/* API信息管理 */}
<Card style={{ marginTop: '10px' }}>
<SettingsAPIInfo options={inputs} refresh={onRefresh} />
</Card>
{/* 常见问答管理 */}
<Card style={{ marginTop: '10px' }}>
<SettingsFAQ options={inputs} refresh={onRefresh} />

View File

@@ -0,0 +1,65 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin } from '@douyinfe/semi-ui';
import SettingsDrawing from '../../pages/Setting/Drawing/SettingsDrawing.js';
import { API, showError } from '../../helpers';
const DrawingSetting = () => {
let [inputs, setInputs] = useState({
/* 绘图设置 */
DrawingEnabled: false,
MjNotifyEnabled: false,
MjAccountFilterEnabled: false,
MjForwardUrlEnabled: false,
MjModeClearEnabled: false,
MjActionCheckSuccessEnabled: false,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key.endsWith('Enabled')) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
} catch (error) {
showError('刷新失败');
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
{/* 绘图设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsDrawing options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default DrawingSetting;

View File

@@ -1,66 +1,44 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
import { Card, Spin } from '@douyinfe/semi-ui';
import SettingsGeneral from '../../pages/Setting/Operation/SettingsGeneral.js';
import SettingsDrawing from '../../pages/Setting/Operation/SettingsDrawing.js';
import SettingsSensitiveWords from '../../pages/Setting/Operation/SettingsSensitiveWords.js';
import SettingsLog from '../../pages/Setting/Operation/SettingsLog.js';
import SettingsDataDashboard from '../../pages/Setting/Operation/SettingsDataDashboard.js';
import SettingsMonitoring from '../../pages/Setting/Operation/SettingsMonitoring.js';
import SettingsCreditLimit from '../../pages/Setting/Operation/SettingsCreditLimit.js';
import ModelSettingsVisualEditor from '../../pages/Setting/Operation/ModelSettingsVisualEditor.js';
import GroupRatioSettings from '../../pages/Setting/Operation/GroupRatioSettings.js';
import ModelRatioSettings from '../../pages/Setting/Operation/ModelRatioSettings.js';
import { API, showError, showSuccess } from '../../helpers';
import SettingsChats from '../../pages/Setting/Operation/SettingsChats.js';
import { useTranslation } from 'react-i18next';
import ModelRatioNotSetEditor from '../../pages/Setting/Operation/ModelRationNotSetEditor.js';
import { API, showError } from '../../helpers';
const OperationSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
/* 额度相关 */
QuotaForNewUser: 0,
PreConsumedQuota: 0,
QuotaForInviter: 0,
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
StreamCacheQueueLength: 0,
ModelRatio: '',
CacheRatio: '',
CompletionRatio: '',
ModelPrice: '',
GroupRatio: '',
GroupGroupRatio: '',
UserUsableGroups: '',
/* 通用设置 */
TopUpLink: '',
'general_setting.docs_link': '',
// ChatLink2: '', // 添加的新状态变量
QuotaPerUnit: 0,
AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false,
ChannelDisableThreshold: 0,
LogConsumeEnabled: false,
RetryTimes: 0,
DisplayInCurrencyEnabled: false,
DisplayTokenStatEnabled: false,
CheckSensitiveEnabled: false,
CheckSensitiveOnPromptEnabled: false,
CheckSensitiveOnCompletionEnabled: '',
StopOnSensitiveEnabled: '',
SensitiveWords: '',
MjNotifyEnabled: false,
MjAccountFilterEnabled: false,
MjModeClearEnabled: false,
MjForwardUrlEnabled: false,
MjActionCheckSuccessEnabled: false,
DrawingEnabled: false,
DataExportEnabled: false,
DataExportDefaultTime: 'hour',
DataExportInterval: 5,
DefaultCollapseSidebar: false, // 默认折叠侧边栏
RetryTimes: 0,
Chats: '[]',
DefaultCollapseSidebar: false,
DemoSiteEnabled: false,
SelfUseModeEnabled: false,
/* 敏感词设置 */
CheckSensitiveEnabled: false,
CheckSensitiveOnPromptEnabled: false,
SensitiveWords: '',
/* 日志设置 */
LogConsumeEnabled: false,
/* 监控设置 */
ChannelDisableThreshold: 0,
QuotaRemindThreshold: 0,
AutomaticDisableChannelEnabled: false,
AutomaticEnableChannelEnabled: false,
AutomaticDisableKeywords: '',
});
@@ -72,17 +50,6 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
if (
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'GroupGroupRatio' ||
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' ||
item.key === 'CacheRatio'
) {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
if (
item.key.endsWith('Enabled') ||
['DefaultCollapseSidebar'].includes(item.key)
@@ -121,10 +88,6 @@ const OperationSetting = () => {
<Card style={{ marginTop: '10px' }}>
<SettingsGeneral options={inputs} refresh={onRefresh} />
</Card>
{/* 绘图设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsDrawing options={inputs} refresh={onRefresh} />
</Card>
{/* 屏蔽词过滤设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsSensitiveWords options={inputs} refresh={onRefresh} />
@@ -133,10 +96,6 @@ const OperationSetting = () => {
<Card style={{ marginTop: '10px' }}>
<SettingsLog options={inputs} refresh={onRefresh} />
</Card>
{/* 数据看板 */}
<Card style={{ marginTop: '10px' }}>
<SettingsDataDashboard options={inputs} refresh={onRefresh} />
</Card>
{/* 监控设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsMonitoring options={inputs} refresh={onRefresh} />
@@ -145,28 +104,6 @@ const OperationSetting = () => {
<Card style={{ marginTop: '10px' }}>
<SettingsCreditLimit options={inputs} refresh={onRefresh} />
</Card>
{/* 聊天设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsChats options={inputs} refresh={onRefresh} />
</Card>
{/* 分组倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<GroupRatioSettings options={inputs} refresh={onRefresh} />
</Card>
{/* 合并模型倍率设置和可视化倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<Tabs type='line'>
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
<ModelRatioSettings options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
<Tabs.TabPane tab={t('可视化倍率设置')} itemKey='visual'>
<ModelSettingsVisualEditor options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
<Tabs.TabPane tab={t('未设置倍率模型')} itemKey='unset_models'>
<ModelRatioNotSetEditor options={inputs} refresh={onRefresh} />
</Tabs.TabPane>
</Tabs>
</Card>
</Spin>
</>
);

View File

@@ -0,0 +1,88 @@
import React, { useEffect, useState } from 'react';
import { Card, Spin } from '@douyinfe/semi-ui';
import SettingsGeneralPayment from '../../pages/Setting/Payment/SettingsGeneralPayment.js';
import SettingsPaymentGateway from '../../pages/Setting/Payment/SettingsPaymentGateway.js';
import { API, showError } from '../../helpers';
import { useTranslation } from 'react-i18next';
const PaymentSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
ServerAddress: '',
PayAddress: '',
EpayId: '',
EpayKey: '',
Price: 7.3,
MinTopUp: 1,
TopupGroupRatio: '',
CustomCallbackAddress: '',
PayMethods: '',
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
switch (item.key) {
case 'TopupGroupRatio':
try {
newInputs[item.key] = JSON.stringify(JSON.parse(item.value), null, 2);
} catch (error) {
console.error('解析TopupGroupRatio出错:', error);
newInputs[item.key] = item.value;
}
break;
case 'Price':
case 'MinTopUp':
newInputs[item.key] = parseFloat(item.value);
break;
default:
if (item.key.endsWith('Enabled')) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
}
break;
}
});
setInputs(newInputs);
} else {
showError(t(message));
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
} catch (error) {
showError(t('刷新失败'));
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
<Card style={{ marginTop: '10px' }}>
<SettingsGeneralPayment options={inputs} refresh={onRefresh} />
</Card>
<Card style={{ marginTop: '10px' }}>
<SettingsPaymentGateway options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default PaymentSetting;

Some files were not shown because too many files have changed in this diff Show More