mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 02:09:43 +00:00
309 lines
9.4 KiB
Go
309 lines
9.4 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"strings"
|
|
|
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
)
|
|
|
|
type usageBillingRepository struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func NewUsageBillingRepository(_ *dbent.Client, sqlDB *sql.DB) service.UsageBillingRepository {
|
|
return &usageBillingRepository{db: sqlDB}
|
|
}
|
|
|
|
func (r *usageBillingRepository) Apply(ctx context.Context, cmd *service.UsageBillingCommand) (_ *service.UsageBillingApplyResult, err error) {
|
|
if cmd == nil {
|
|
return &service.UsageBillingApplyResult{}, nil
|
|
}
|
|
if r == nil || r.db == nil {
|
|
return nil, errors.New("usage billing repository db is nil")
|
|
}
|
|
|
|
cmd.Normalize()
|
|
if cmd.RequestID == "" {
|
|
return nil, service.ErrUsageBillingRequestIDRequired
|
|
}
|
|
|
|
tx, err := r.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if tx != nil {
|
|
_ = tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
applied, err := r.claimUsageBillingKey(ctx, tx, cmd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !applied {
|
|
return &service.UsageBillingApplyResult{Applied: false}, nil
|
|
}
|
|
|
|
result := &service.UsageBillingApplyResult{Applied: true}
|
|
if err := r.applyUsageBillingEffects(ctx, tx, cmd, result); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, err
|
|
}
|
|
tx = nil
|
|
return result, nil
|
|
}
|
|
|
|
func (r *usageBillingRepository) claimUsageBillingKey(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand) (bool, error) {
|
|
var id int64
|
|
err := tx.QueryRowContext(ctx, `
|
|
INSERT INTO usage_billing_dedup (request_id, api_key_id, request_fingerprint)
|
|
VALUES ($1, $2, $3)
|
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
|
RETURNING id
|
|
`, cmd.RequestID, cmd.APIKeyID, cmd.RequestFingerprint).Scan(&id)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
var existingFingerprint string
|
|
if err := tx.QueryRowContext(ctx, `
|
|
SELECT request_fingerprint
|
|
FROM usage_billing_dedup
|
|
WHERE request_id = $1 AND api_key_id = $2
|
|
`, cmd.RequestID, cmd.APIKeyID).Scan(&existingFingerprint); err != nil {
|
|
return false, err
|
|
}
|
|
if strings.TrimSpace(existingFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
|
return false, service.ErrUsageBillingRequestConflict
|
|
}
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
var archivedFingerprint string
|
|
err = tx.QueryRowContext(ctx, `
|
|
SELECT request_fingerprint
|
|
FROM usage_billing_dedup_archive
|
|
WHERE request_id = $1 AND api_key_id = $2
|
|
`, cmd.RequestID, cmd.APIKeyID).Scan(&archivedFingerprint)
|
|
if err == nil {
|
|
if strings.TrimSpace(archivedFingerprint) != strings.TrimSpace(cmd.RequestFingerprint) {
|
|
return false, service.ErrUsageBillingRequestConflict
|
|
}
|
|
return false, nil
|
|
}
|
|
if !errors.Is(err, sql.ErrNoRows) {
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, tx *sql.Tx, cmd *service.UsageBillingCommand, result *service.UsageBillingApplyResult) error {
|
|
if cmd.SubscriptionCost > 0 && cmd.SubscriptionID != nil {
|
|
if err := incrementUsageBillingSubscription(ctx, tx, *cmd.SubscriptionID, cmd.SubscriptionCost); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if cmd.BalanceCost > 0 {
|
|
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if cmd.APIKeyQuotaCost > 0 {
|
|
exhausted, err := incrementUsageBillingAPIKeyQuota(ctx, tx, cmd.APIKeyID, cmd.APIKeyQuotaCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
result.APIKeyQuotaExhausted = exhausted
|
|
}
|
|
|
|
if cmd.APIKeyRateLimitCost > 0 {
|
|
if err := incrementUsageBillingAPIKeyRateLimit(ctx, tx, cmd.APIKeyID, cmd.APIKeyRateLimitCost); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if cmd.AccountQuotaCost > 0 && strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) {
|
|
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscriptionID int64, costUSD float64) error {
|
|
const updateSQL = `
|
|
UPDATE user_subscriptions us
|
|
SET
|
|
daily_usage_usd = us.daily_usage_usd + $1,
|
|
weekly_usage_usd = us.weekly_usage_usd + $1,
|
|
monthly_usage_usd = us.monthly_usage_usd + $1,
|
|
updated_at = NOW()
|
|
FROM groups g
|
|
WHERE us.id = $2
|
|
AND us.deleted_at IS NULL
|
|
AND us.group_id = g.id
|
|
AND g.deleted_at IS NULL
|
|
`
|
|
res, err := tx.ExecContext(ctx, updateSQL, costUSD, subscriptionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected > 0 {
|
|
return nil
|
|
}
|
|
return service.ErrSubscriptionNotFound
|
|
}
|
|
|
|
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
|
res, err := tx.ExecContext(ctx, `
|
|
UPDATE users
|
|
SET balance = balance - $1,
|
|
updated_at = NOW()
|
|
WHERE id = $2 AND deleted_at IS NULL
|
|
`, amount, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected > 0 {
|
|
return nil
|
|
}
|
|
return service.ErrUserNotFound
|
|
}
|
|
|
|
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
|
var exhausted bool
|
|
err := tx.QueryRowContext(ctx, `
|
|
UPDATE api_keys
|
|
SET quota_used = quota_used + $1,
|
|
status = CASE
|
|
WHEN quota > 0
|
|
AND status = $3
|
|
AND quota_used < quota
|
|
AND quota_used + $1 >= quota
|
|
THEN $4
|
|
ELSE status
|
|
END,
|
|
updated_at = NOW()
|
|
WHERE id = $2 AND deleted_at IS NULL
|
|
RETURNING quota > 0 AND quota_used >= quota AND quota_used - $1 < quota
|
|
`, amount, apiKeyID, service.StatusAPIKeyActive, service.StatusAPIKeyQuotaExhausted).Scan(&exhausted)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, service.ErrAPIKeyNotFound
|
|
}
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return exhausted, nil
|
|
}
|
|
|
|
func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKeyID int64, cost float64) error {
|
|
res, err := tx.ExecContext(ctx, `
|
|
UPDATE api_keys SET
|
|
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
|
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
|
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
|
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
|
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
|
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
|
updated_at = NOW()
|
|
WHERE id = $2 AND deleted_at IS NULL
|
|
`, cost, apiKeyID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return service.ErrAPIKeyNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
|
rows, err := tx.QueryContext(ctx,
|
|
`UPDATE accounts SET extra = (
|
|
COALESCE(extra, '{}'::jsonb)
|
|
|| jsonb_build_object('quota_used', COALESCE((extra->>'quota_used')::numeric, 0) + $1)
|
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
|
jsonb_build_object(
|
|
'quota_daily_used',
|
|
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
|
+ '24 hours'::interval <= NOW()
|
|
THEN $1
|
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
|
'quota_daily_start',
|
|
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
|
+ '24 hours'::interval <= NOW()
|
|
THEN `+nowUTC+`
|
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
|
)
|
|
ELSE '{}'::jsonb END
|
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
|
jsonb_build_object(
|
|
'quota_weekly_used',
|
|
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
|
+ '168 hours'::interval <= NOW()
|
|
THEN $1
|
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
|
'quota_weekly_start',
|
|
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
|
+ '168 hours'::interval <= NOW()
|
|
THEN `+nowUTC+`
|
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
|
)
|
|
ELSE '{}'::jsonb END
|
|
), updated_at = NOW()
|
|
WHERE id = $2 AND deleted_at IS NULL
|
|
RETURNING
|
|
COALESCE((extra->>'quota_used')::numeric, 0),
|
|
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
|
amount, accountID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var newUsed, limit float64
|
|
if rows.Next() {
|
|
if err := rows.Scan(&newUsed, &limit); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
return service.ErrAccountNotFound
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
|
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
|
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|