mirror of
https://github.com/Wei-Shaw/sub2api.git
synced 2026-03-30 02:09:43 +00:00
feat(groups): add Claude Code client restriction and session isolation
- Add claude_code_only field to restrict groups to Claude Code clients only - Add fallback_group_id for non-Claude Code requests to use alternate group - Implement ClaudeCodeValidator for User-Agent detection - Add group-level session binding isolation (groupID in Redis key) - Prevent cross-group sticky session pollution - Update frontend with Claude Code restriction controls
This commit is contained in:
@@ -51,6 +51,10 @@ type Group struct {
|
||||
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
||||
// ImagePrice4k holds the value of the "image_price_4k" field.
|
||||
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
||||
// 是否仅允许 Claude Code 客户端
|
||||
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||
// 非 Claude Code 请求降级使用的分组 ID
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
// Edges holds the relations/edges for other nodes in the graph.
|
||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||
Edges GroupEdges `json:"edges"`
|
||||
@@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case group.FieldIsExclusive:
|
||||
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
|
||||
values[i] = new(sql.NullBool)
|
||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case group.FieldID, group.FieldDefaultValidityDays:
|
||||
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||
values[i] = new(sql.NullString)
|
||||
@@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
||||
_m.ImagePrice4k = new(float64)
|
||||
*_m.ImagePrice4k = value.Float64
|
||||
}
|
||||
case group.FieldClaudeCodeOnly:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ClaudeCodeOnly = value.Bool
|
||||
}
|
||||
case group.FieldFallbackGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i])
|
||||
} else if value.Valid {
|
||||
_m.FallbackGroupID = new(int64)
|
||||
*_m.FallbackGroupID = value.Int64
|
||||
}
|
||||
default:
|
||||
_m.selectValues.Set(columns[i], values[i])
|
||||
}
|
||||
@@ -440,6 +457,14 @@ func (_m *Group) String() string {
|
||||
builder.WriteString("image_price_4k=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("claude_code_only=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.FallbackGroupID; v != nil {
|
||||
builder.WriteString("fallback_group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -49,6 +49,10 @@ const (
|
||||
FieldImagePrice2k = "image_price_2k"
|
||||
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
||||
FieldImagePrice4k = "image_price_4k"
|
||||
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
||||
FieldClaudeCodeOnly = "claude_code_only"
|
||||
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||
FieldFallbackGroupID = "fallback_group_id"
|
||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||
EdgeAPIKeys = "api_keys"
|
||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||
@@ -141,6 +145,8 @@ var Columns = []string{
|
||||
FieldImagePrice1k,
|
||||
FieldImagePrice2k,
|
||||
FieldImagePrice4k,
|
||||
FieldClaudeCodeOnly,
|
||||
FieldFallbackGroupID,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -196,6 +202,8 @@ var (
|
||||
SubscriptionTypeValidator func(string) error
|
||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||
DefaultDefaultValidityDays int
|
||||
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||
DefaultClaudeCodeOnly bool
|
||||
)
|
||||
|
||||
// OrderOption defines the ordering options for the Group queries.
|
||||
@@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
||||
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByFallbackGroupID orders the results by the fallback_group_id field.
|
||||
func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIKeysCount orders the results by api_keys count.
|
||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||
return func(s *sql.Selector) {
|
||||
|
||||
@@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
||||
func ClaudeCodeOnly(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||
}
|
||||
|
||||
// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ.
|
||||
func FallbackGroupID(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||
@@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
||||
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||
}
|
||||
|
||||
// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field.
|
||||
func ClaudeCodeOnlyNEQ(v bool) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDEQ(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDNEQ(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDIn(vs ...int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...))
|
||||
}
|
||||
|
||||
// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDNotIn(vs ...int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...))
|
||||
}
|
||||
|
||||
// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDGT(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDGTE(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDLT(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDLTE(v int64) predicate.Group {
|
||||
return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v))
|
||||
}
|
||||
|
||||
// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDIsNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID))
|
||||
}
|
||||
|
||||
// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field.
|
||||
func FallbackGroupIDNotNil() predicate.Group {
|
||||
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||
}
|
||||
|
||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||
func HasAPIKeys() predicate.Group {
|
||||
return predicate.Group(func(s *sql.Selector) {
|
||||
|
||||
@@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetClaudeCodeOnly(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate {
|
||||
_c.mutation.SetFallbackGroupID(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||
func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
||||
if v != nil {
|
||||
_c.SetFallbackGroupID(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||
_c.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error {
|
||||
v := group.DefaultDefaultValidityDays
|
||||
_c.mutation.SetDefaultValidityDays(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
v := group.DefaultClaudeCodeOnly
|
||||
_c.mutation.SetClaudeCodeOnly(v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error {
|
||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||
}
|
||||
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||
_node.ImagePrice4k = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
_node.ClaudeCodeOnly = value
|
||||
}
|
||||
if value, ok := _c.mutation.FallbackGroupID(); ok {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
_node.FallbackGroupID = &value
|
||||
}
|
||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
||||
u.Set(group.FieldClaudeCodeOnly, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldClaudeCodeOnly)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert {
|
||||
u.Set(group.FieldFallbackGroupID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||
func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert {
|
||||
u.SetExcluded(group.FieldFallbackGroupID)
|
||||
return u
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||
func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert {
|
||||
u.Add(group.FieldFallbackGroupID, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
||||
u.SetNull(group.FieldFallbackGroupID)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||
// Using this option is equivalent to using:
|
||||
//
|
||||
@@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetClaudeCodeOnly(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateClaudeCodeOnly()
|
||||
})
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||
func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||
func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||
if len(u.create.conflict) == 0 {
|
||||
@@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetClaudeCodeOnly(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateClaudeCodeOnly()
|
||||
})
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.SetFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||
func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.AddFallbackGroupID(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||
func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.UpdateFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
||||
return u.Update(func(s *GroupUpsert) {
|
||||
s.ClearFallbackGroupID()
|
||||
})
|
||||
}
|
||||
|
||||
// Exec executes the query.
|
||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||
if u.create.err != nil {
|
||||
|
||||
@@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
||||
_u.mutation.SetClaudeCodeOnly(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetClaudeCodeOnly(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate {
|
||||
_u.mutation.ResetFallbackGroupID()
|
||||
_u.mutation.SetFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||
func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate {
|
||||
if v != nil {
|
||||
_u.SetFallbackGroupID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds value to the "fallback_group_id" field.
|
||||
func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate {
|
||||
_u.mutation.AddFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
||||
_u.mutation.ClearFallbackGroupID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.ImagePrice4kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.FallbackGroupID(); ok {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
|
||||
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.FallbackGroupIDCleared() {
|
||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
@@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
||||
_u.mutation.SetClaudeCodeOnly(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetClaudeCodeOnly(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne {
|
||||
_u.mutation.ResetFallbackGroupID()
|
||||
_u.mutation.SetFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||
func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetFallbackGroupID(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds value to the "fallback_group_id" field.
|
||||
func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne {
|
||||
_u.mutation.AddFallbackGroupID(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
||||
_u.mutation.ClearFallbackGroupID()
|
||||
return _u
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||
_u.mutation.AddAPIKeyIDs(ids...)
|
||||
@@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
||||
if _u.mutation.ImagePrice4kCleared() {
|
||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||
}
|
||||
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||
}
|
||||
if value, ok := _u.mutation.FallbackGroupID(); ok {
|
||||
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
|
||||
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||
}
|
||||
if _u.mutation.FallbackGroupIDCleared() {
|
||||
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||
}
|
||||
if _u.mutation.APIKeysCleared() {
|
||||
edge := &sqlgraph.EdgeSpec{
|
||||
Rel: sqlgraph.O2M,
|
||||
|
||||
@@ -221,6 +221,8 @@ var (
|
||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||
}
|
||||
// GroupsTable holds the schema information for the "groups" table.
|
||||
GroupsTable = &schema.Table{
|
||||
|
||||
@@ -3590,6 +3590,9 @@ type GroupMutation struct {
|
||||
addimage_price_2k *float64
|
||||
image_price_4k *float64
|
||||
addimage_price_4k *float64
|
||||
claude_code_only *bool
|
||||
fallback_group_id *int64
|
||||
addfallback_group_id *int64
|
||||
clearedFields map[string]struct{}
|
||||
api_keys map[int64]struct{}
|
||||
removedapi_keys map[int64]struct{}
|
||||
@@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() {
|
||||
delete(m.clearedFields, group.FieldImagePrice4k)
|
||||
}
|
||||
|
||||
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
|
||||
m.claude_code_only = &b
|
||||
}
|
||||
|
||||
// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
|
||||
func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
|
||||
v := m.claude_code_only
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
|
||||
}
|
||||
return oldValue.ClaudeCodeOnly, nil
|
||||
}
|
||||
|
||||
// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
|
||||
func (m *GroupMutation) ResetClaudeCodeOnly() {
|
||||
m.claude_code_only = nil
|
||||
}
|
||||
|
||||
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||
func (m *GroupMutation) SetFallbackGroupID(i int64) {
|
||||
m.fallback_group_id = &i
|
||||
m.addfallback_group_id = nil
|
||||
}
|
||||
|
||||
// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
|
||||
func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
|
||||
v := m.fallback_group_id
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
|
||||
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
|
||||
}
|
||||
return oldValue.FallbackGroupID, nil
|
||||
}
|
||||
|
||||
// AddFallbackGroupID adds i to the "fallback_group_id" field.
|
||||
func (m *GroupMutation) AddFallbackGroupID(i int64) {
|
||||
if m.addfallback_group_id != nil {
|
||||
*m.addfallback_group_id += i
|
||||
} else {
|
||||
m.addfallback_group_id = &i
|
||||
}
|
||||
}
|
||||
|
||||
// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
|
||||
func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
|
||||
v := m.addfallback_group_id
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||
func (m *GroupMutation) ClearFallbackGroupID() {
|
||||
m.fallback_group_id = nil
|
||||
m.addfallback_group_id = nil
|
||||
m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
|
||||
}
|
||||
|
||||
// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
|
||||
func (m *GroupMutation) FallbackGroupIDCleared() bool {
|
||||
_, ok := m.clearedFields[group.FieldFallbackGroupID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
|
||||
func (m *GroupMutation) ResetFallbackGroupID() {
|
||||
m.fallback_group_id = nil
|
||||
m.addfallback_group_id = nil
|
||||
delete(m.clearedFields, group.FieldFallbackGroupID)
|
||||
}
|
||||
|
||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||
if m.api_keys == nil {
|
||||
@@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *GroupMutation) Fields() []string {
|
||||
fields := make([]string, 0, 17)
|
||||
fields := make([]string, 0, 19)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, group.FieldCreatedAt)
|
||||
}
|
||||
@@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string {
|
||||
if m.image_price_4k != nil {
|
||||
fields = append(fields, group.FieldImagePrice4k)
|
||||
}
|
||||
if m.claude_code_only != nil {
|
||||
fields = append(fields, group.FieldClaudeCodeOnly)
|
||||
}
|
||||
if m.fallback_group_id != nil {
|
||||
fields = append(fields, group.FieldFallbackGroupID)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ImagePrice2k()
|
||||
case group.FieldImagePrice4k:
|
||||
return m.ImagePrice4k()
|
||||
case group.FieldClaudeCodeOnly:
|
||||
return m.ClaudeCodeOnly()
|
||||
case group.FieldFallbackGroupID:
|
||||
return m.FallbackGroupID()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
||||
return m.OldImagePrice2k(ctx)
|
||||
case group.FieldImagePrice4k:
|
||||
return m.OldImagePrice4k(ctx)
|
||||
case group.FieldClaudeCodeOnly:
|
||||
return m.OldClaudeCodeOnly(ctx)
|
||||
case group.FieldFallbackGroupID:
|
||||
return m.OldFallbackGroupID(ctx)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetImagePrice4k(v)
|
||||
return nil
|
||||
case group.FieldClaudeCodeOnly:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetClaudeCodeOnly(v)
|
||||
return nil
|
||||
case group.FieldFallbackGroupID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetFallbackGroupID(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
@@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string {
|
||||
if m.addimage_price_4k != nil {
|
||||
fields = append(fields, group.FieldImagePrice4k)
|
||||
}
|
||||
if m.addfallback_group_id != nil {
|
||||
fields = append(fields, group.FieldFallbackGroupID)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
||||
return m.AddedImagePrice2k()
|
||||
case group.FieldImagePrice4k:
|
||||
return m.AddedImagePrice4k()
|
||||
case group.FieldFallbackGroupID:
|
||||
return m.AddedFallbackGroupID()
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
||||
}
|
||||
m.AddImagePrice4k(v)
|
||||
return nil
|
||||
case group.FieldFallbackGroupID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.AddFallbackGroupID(v)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||
}
|
||||
@@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(group.FieldImagePrice4k) {
|
||||
fields = append(fields, group.FieldImagePrice4k)
|
||||
}
|
||||
if m.FieldCleared(group.FieldFallbackGroupID) {
|
||||
fields = append(fields, group.FieldFallbackGroupID)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error {
|
||||
case group.FieldImagePrice4k:
|
||||
m.ClearImagePrice4k()
|
||||
return nil
|
||||
case group.FieldFallbackGroupID:
|
||||
m.ClearFallbackGroupID()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group nullable field %s", name)
|
||||
}
|
||||
@@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error {
|
||||
case group.FieldImagePrice4k:
|
||||
m.ResetImagePrice4k()
|
||||
return nil
|
||||
case group.FieldClaudeCodeOnly:
|
||||
m.ResetClaudeCodeOnly()
|
||||
return nil
|
||||
case group.FieldFallbackGroupID:
|
||||
m.ResetFallbackGroupID()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown Group field %s", name)
|
||||
}
|
||||
|
||||
@@ -270,6 +270,10 @@ func init() {
|
||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||
proxyMixin := schema.Proxy{}.Mixin()
|
||||
proxyMixinHooks1 := proxyMixin[1].Hooks()
|
||||
proxy.Hooks[0] = proxyMixinHooks1[0]
|
||||
|
||||
@@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field {
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||
|
||||
// Claude Code 客户端限制 (added by migration 029)
|
||||
field.Bool("claude_code_only").
|
||||
Default(false).
|
||||
Comment("是否仅允许 Claude Code 客户端"),
|
||||
field.Int64("fallback_group_id").
|
||||
Optional().
|
||||
Nillable().
|
||||
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge {
|
||||
edge.From("allowed_users", User.Type).
|
||||
Ref("allowed_groups").
|
||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ type CreateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
@@ -55,6 +57,8 @@ type UpdateGroupRequest struct {
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
@@ -150,6 +154,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -188,6 +194,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
|
||||
@@ -52,6 +52,10 @@ type Group struct {
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
|
||||
@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
@@ -13,6 +14,26 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
||||
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||
|
||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||
// 返回更新后的 context
|
||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||
// 解析请求体为 map
|
||||
var bodyMap map[string]any
|
||||
if len(body) > 0 {
|
||||
_ = json.Unmarshal(body, &bodyMap)
|
||||
}
|
||||
|
||||
// 验证是否为 Claude Code 客户端
|
||||
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||
|
||||
// 更新 request context
|
||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
// 并发槽位等待相关常量
|
||||
//
|
||||
// 性能优化说明:
|
||||
|
||||
@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,4 +7,6 @@ type Key string
|
||||
const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
|
||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||
)
|
||||
|
||||
@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
|
||||
// 格式: sticky_session:{groupID}:{sessionHash}
|
||||
func buildSessionKey(groupID int64, sessionHash string) string {
|
||||
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
||||
key := stickySessionPrefix + sessionHash
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays)
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
|
||||
builder := r.client.Group.UpdateOneID(groupIn.ID).
|
||||
SetName(groupIn.Name).
|
||||
SetDescription(groupIn.Description).
|
||||
SetPlatform(groupIn.Platform).
|
||||
@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
Save(ctx)
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
@@ -103,6 +103,8 @@ type CreateGroupInput struct {
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -120,6 +122,8 @@ type UpdateGroupInput struct {
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -516,6 +520,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
|
||||
// 校验降级分组
|
||||
if input.FallbackGroupID != nil {
|
||||
if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -530,6 +541,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -553,6 +566,29 @@ func normalizePrice(price *float64) *float64 {
|
||||
return price
|
||||
}
|
||||
|
||||
// validateFallbackGroup 校验降级分组的有效性
|
||||
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||
// fallbackGroupID: 降级分组 ID
|
||||
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
|
||||
// 不能将自己设置为降级分组
|
||||
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||
return fmt.Errorf("cannot set self as fallback group")
|
||||
}
|
||||
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -603,6 +639,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||
}
|
||||
|
||||
// Claude Code 客户端限制
|
||||
if input.ClaudeCodeOnly != nil {
|
||||
group.ClaudeCodeOnly = *input.ClaudeCodeOnly
|
||||
}
|
||||
if input.FallbackGroupID != nil {
|
||||
// 校验降级分组
|
||||
if *input.FallbackGroupID > 0 {
|
||||
if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group.FallbackGroupID = input.FallbackGroupID
|
||||
} else {
|
||||
// 传入 0 或负数表示清除降级分组
|
||||
group.FallbackGroupID = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
265
backend/internal/service/claude_code_validator.go
Normal file
265
backend/internal/service/claude_code_validator.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
|
||||
// 完全学习自 claude-relay-service 项目的验证逻辑
|
||||
type ClaudeCodeValidator struct{}
|
||||
|
||||
var (
|
||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
||||
|
||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||
systemPromptThreshold = 0.5
|
||||
)
|
||||
|
||||
// Claude Code 官方 System Prompt 模板
|
||||
// 从 claude-relay-service/src/utils/contents.js 提取
|
||||
var claudeCodeSystemPrompts = []string{
|
||||
// claudeOtherSystemPrompt1 - Primary
|
||||
"You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPrompt3 - Agent SDK
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
|
||||
|
||||
// claudeOtherSystemPrompt4 - Compact Agent SDK
|
||||
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
|
||||
|
||||
// exploreAgentSystemPrompt
|
||||
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
|
||||
"You are a helpful AI assistant tasked with summarizing conversations.",
|
||||
|
||||
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
|
||||
"You are an interactive CLI tool that helps users",
|
||||
}
|
||||
|
||||
// NewClaudeCodeValidator 创建验证器实例
|
||||
func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||
return &ClaudeCodeValidator{}
|
||||
}
|
||||
|
||||
// Validate 验证请求是否来自 Claude Code CLI
|
||||
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||
//
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||
// Step 3: 对于 messages 路径,进行严格验证:
|
||||
// - System prompt 相似度检查
|
||||
// - X-App header 检查
|
||||
// - anthropic-beta header 检查
|
||||
// - anthropic-version header 检查
|
||||
// - metadata.user_id 格式验证
|
||||
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
|
||||
// Step 1: User-Agent 检查
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !claudeCodeUAPattern.MatchString(ua) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Step 2: 非 messages 路径,只要 UA 匹配就通过
|
||||
path := r.URL.Path
|
||||
if !strings.Contains(path, "messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Step 3: messages 路径,进行严格验证
|
||||
|
||||
// 3.1 检查 system prompt 相似度
|
||||
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.2 检查必需的 headers(值不为空即可)
|
||||
xApp := r.Header.Get("X-App")
|
||||
if xApp == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicBeta := r.Header.Get("anthropic-beta")
|
||||
if anthropicBeta == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicVersion := r.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.3 验证 metadata.user_id
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
metadata, ok := body["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if !userIDPattern.MatchString(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
|
||||
// 使用字符串相似度匹配(Dice coefficient)
|
||||
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 model 字段
|
||||
if _, ok := body["model"].(string); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取 system 字段
|
||||
systemEntries, ok := body["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每个 system entry
|
||||
for _, entry := range systemEntries {
|
||||
entryMap, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok := entryMap["text"].(string)
|
||||
if !ok || text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算与所有模板的最佳相似度
|
||||
bestScore := v.bestSimilarityScore(text)
|
||||
if bestScore >= systemPromptThreshold {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
|
||||
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
|
||||
normalizedText := normalizePrompt(text)
|
||||
bestScore := 0.0
|
||||
|
||||
for _, template := range claudeCodeSystemPrompts {
|
||||
normalizedTemplate := normalizePrompt(template)
|
||||
score := diceCoefficient(normalizedText, normalizedTemplate)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
return bestScore
|
||||
}
|
||||
|
||||
// normalizePrompt 标准化提示词文本(去除多余空白)
|
||||
func normalizePrompt(text string) string {
|
||||
// 将所有空白字符替换为单个空格,并去除首尾空白
|
||||
return strings.Join(strings.Fields(text), " ")
|
||||
}
|
||||
|
||||
// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient)
|
||||
// 这是 string-similarity 库使用的算法
|
||||
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
|
||||
func diceCoefficient(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
if len(a) < 2 || len(b) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 生成 bigrams
|
||||
bigramsA := getBigrams(a)
|
||||
bigramsB := getBigrams(b)
|
||||
|
||||
if len(bigramsA) == 0 || len(bigramsB) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算交集大小
|
||||
intersection := 0
|
||||
for bigram, countA := range bigramsA {
|
||||
if countB, exists := bigramsB[bigram]; exists {
|
||||
if countA < countB {
|
||||
intersection += countA
|
||||
} else {
|
||||
intersection += countB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算总 bigram 数量
|
||||
totalA := 0
|
||||
for _, count := range bigramsA {
|
||||
totalA += count
|
||||
}
|
||||
totalB := 0
|
||||
for _, count := range bigramsB {
|
||||
totalB += count
|
||||
}
|
||||
|
||||
return float64(2*intersection) / float64(totalA+totalB)
|
||||
}
|
||||
|
||||
// getBigrams 获取字符串的所有 bigrams(相邻字符对)
|
||||
func getBigrams(s string) map[string]int {
|
||||
bigrams := make(map[string]int)
|
||||
runes := []rune(strings.ToLower(s))
|
||||
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
bigram := string(runes[i : i+2])
|
||||
bigrams[bigram]++
|
||||
}
|
||||
|
||||
return bigrams
|
||||
}
|
||||
|
||||
// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景)
|
||||
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
|
||||
return claudeCodeUAPattern.MatchString(ua)
|
||||
}
|
||||
|
||||
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
|
||||
// 只要存在匹配的系统提示词就返回 true(用于宽松检测)
|
||||
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
return v.hasClaudeCodeSystemPrompt(body)
|
||||
}
|
||||
|
||||
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
|
||||
func IsClaudeCodeClient(ctx context.Context) bool {
|
||||
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
|
||||
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
|
||||
}
|
||||
@@ -56,6 +56,9 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
type GatewayCache interface {
|
||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
func derefGroupID(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
}
|
||||
return *groupID
|
||||
}
|
||||
|
||||
type AccountWaitPlan struct {
|
||||
@@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
@@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
platform = group.Platform
|
||||
|
||||
// 检查 Claude Code 客户端限制
|
||||
if group.ClaudeCodeOnly {
|
||||
isClaudeCode := IsClaudeCodeClient(ctx)
|
||||
if !isClaudeCode {
|
||||
// 非 Claude Code 客户端,检查是否有降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
// 使用降级分组重新调度
|
||||
fallbackGroupID := *group.FallbackGroupID
|
||||
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
|
||||
}
|
||||
// 无降级分组,拒绝访问
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 无分组时只使用原生 anthropic 平台
|
||||
platform = PlatformAnthropic
|
||||
@@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
@@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||
@@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
@@ -506,7 +539,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -556,7 +589,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
@@ -584,7 +617,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
@@ -592,7 +625,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
@@ -619,6 +652,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
|
||||
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
|
||||
// - 有降级分组:返回降级分组的 ID
|
||||
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
|
||||
if groupID == nil {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 分组启用了 Claude Code 限制
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 非 Claude Code 客户端,检查降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
return group.FallbackGroupID, nil
|
||||
}
|
||||
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
@@ -738,13 +807,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
preferOAuth := platform == PlatformGemini
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
@@ -811,7 +880,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -827,14 +896,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
@@ -903,7 +972,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
}
|
||||
if usable {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -217,7 +217,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
|
||||
@@ -22,6 +22,10 @@ type Group struct {
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
|
||||
@@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
||||
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
@@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
|
||||
// 4. Set sticky session
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
@@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
|
||||
// ============ Layer 1: Sticky session ============
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
@@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
@@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
|
||||
21
backend/migrations/029_add_group_claude_code_restriction.sql
Normal file
21
backend/migrations/029_add_group_claude_code_restriction.sql
Normal file
@@ -0,0 +1,21 @@
|
||||
-- 029_add_group_claude_code_restriction.sql
|
||||
-- 添加分组级别的 Claude Code 客户端限制功能
|
||||
|
||||
-- 添加 claude_code_only 字段:是否仅允许 Claude Code 客户端
|
||||
ALTER TABLE groups
|
||||
ADD COLUMN IF NOT EXISTS claude_code_only BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- 添加 fallback_group_id 字段:非 Claude Code 请求降级到的分组
|
||||
ALTER TABLE groups
|
||||
ADD COLUMN IF NOT EXISTS fallback_group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
|
||||
|
||||
-- 添加索引优化查询
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_claude_code_only
|
||||
ON groups(claude_code_only) WHERE deleted_at IS NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id
|
||||
ON groups(fallback_group_id) WHERE deleted_at IS NULL AND fallback_group_id IS NOT NULL;
|
||||
|
||||
-- 添加字段注释
|
||||
COMMENT ON COLUMN groups.claude_code_only IS '是否仅允许 Claude Code 客户端访问此分组';
|
||||
COMMENT ON COLUMN groups.fallback_group_id IS '非 Claude Code 请求降级使用的分组 ID';
|
||||
@@ -857,6 +857,15 @@ export default {
|
||||
imagePricing: {
|
||||
title: 'Image Generation Pricing',
|
||||
description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
|
||||
},
|
||||
claudeCode: {
|
||||
title: 'Claude Code Client Restriction',
|
||||
tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.',
|
||||
enabled: 'Claude Code Only',
|
||||
disabled: 'Allow All Clients',
|
||||
fallbackGroup: 'Fallback Group',
|
||||
fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.',
|
||||
noFallback: 'No Fallback (Reject)'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -934,6 +934,15 @@ export default {
|
||||
imagePricing: {
|
||||
title: '图片生成计费',
|
||||
description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
|
||||
},
|
||||
claudeCode: {
|
||||
title: 'Claude Code 客户端限制',
|
||||
tooltip: '启用后,此分组仅允许 Claude Code 官方客户端访问。非 Claude Code 请求将被拒绝或降级到指定分组。',
|
||||
enabled: '仅限 Claude Code',
|
||||
disabled: '允许所有客户端',
|
||||
fallbackGroup: '降级分组',
|
||||
fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝',
|
||||
noFallback: '不降级(直接拒绝)'
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
@@ -263,6 +263,9 @@ export interface Group {
|
||||
image_price_1k: number | null
|
||||
image_price_2k: number | null
|
||||
image_price_4k: number | null
|
||||
// Claude Code 客户端限制
|
||||
claude_code_only: boolean
|
||||
fallback_group_id: number | null
|
||||
account_count?: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
@@ -298,6 +301,15 @@ export interface CreateGroupRequest {
|
||||
platform?: GroupPlatform
|
||||
rate_multiplier?: number
|
||||
is_exclusive?: boolean
|
||||
subscription_type?: SubscriptionType
|
||||
daily_limit_usd?: number | null
|
||||
weekly_limit_usd?: number | null
|
||||
monthly_limit_usd?: number | null
|
||||
image_price_1k?: number | null
|
||||
image_price_2k?: number | null
|
||||
image_price_4k?: number | null
|
||||
claude_code_only?: boolean
|
||||
fallback_group_id?: number | null
|
||||
}
|
||||
|
||||
export interface UpdateGroupRequest {
|
||||
@@ -307,6 +319,15 @@ export interface UpdateGroupRequest {
|
||||
rate_multiplier?: number
|
||||
is_exclusive?: boolean
|
||||
status?: 'active' | 'inactive'
|
||||
subscription_type?: SubscriptionType
|
||||
daily_limit_usd?: number | null
|
||||
weekly_limit_usd?: number | null
|
||||
monthly_limit_usd?: number | null
|
||||
image_price_1k?: number | null
|
||||
image_price_2k?: number | null
|
||||
image_price_4k?: number | null
|
||||
claude_code_only?: boolean
|
||||
fallback_group_id?: number | null
|
||||
}
|
||||
|
||||
// ==================== Account & Proxy Types ====================
|
||||
|
||||
@@ -403,6 +403,62 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Claude Code 客户端限制(仅 anthropic 平台) -->
|
||||
<div v-if="createForm.platform === 'anthropic'" class="border-t pt-4">
|
||||
<div class="mb-1.5 flex items-center gap-1">
|
||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.groups.claudeCode.title') }}
|
||||
</label>
|
||||
<!-- Help Tooltip -->
|
||||
<div class="group relative inline-flex">
|
||||
<Icon
|
||||
name="questionCircle"
|
||||
size="sm"
|
||||
:stroke-width="2"
|
||||
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
|
||||
/>
|
||||
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
|
||||
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
|
||||
<p class="text-xs leading-relaxed text-gray-300">
|
||||
{{ t('admin.groups.claudeCode.tooltip') }}
|
||||
</p>
|
||||
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="createForm.claude_code_only = !createForm.claude_code_only"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 items-center rounded-full transition-colors',
|
||||
createForm.claude_code_only ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'inline-block h-4 w-4 transform rounded-full bg-white shadow transition-transform',
|
||||
createForm.claude_code_only ? 'translate-x-6' : 'translate-x-1'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
<span class="text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ createForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }}
|
||||
</span>
|
||||
</div>
|
||||
<!-- 降级分组选择(仅当启用 claude_code_only 时显示) -->
|
||||
<div v-if="createForm.claude_code_only" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.groups.claudeCode.fallbackGroup') }}</label>
|
||||
<Select
|
||||
v-model="createForm.fallback_group_id"
|
||||
:options="fallbackGroupOptions"
|
||||
:placeholder="t('admin.groups.claudeCode.noFallback')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.groups.claudeCode.fallbackHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</form>
|
||||
|
||||
<template #footer>
|
||||
@@ -648,6 +704,62 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Claude Code 客户端限制(仅 anthropic 平台) -->
|
||||
<div v-if="editForm.platform === 'anthropic'" class="border-t pt-4">
|
||||
<div class="mb-1.5 flex items-center gap-1">
|
||||
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
{{ t('admin.groups.claudeCode.title') }}
|
||||
</label>
|
||||
<!-- Help Tooltip -->
|
||||
<div class="group relative inline-flex">
|
||||
<Icon
|
||||
name="questionCircle"
|
||||
size="sm"
|
||||
:stroke-width="2"
|
||||
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
|
||||
/>
|
||||
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
|
||||
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
|
||||
<p class="text-xs leading-relaxed text-gray-300">
|
||||
{{ t('admin.groups.claudeCode.tooltip') }}
|
||||
</p>
|
||||
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
@click="editForm.claude_code_only = !editForm.claude_code_only"
|
||||
:class="[
|
||||
'relative inline-flex h-6 w-11 items-center rounded-full transition-colors',
|
||||
editForm.claude_code_only ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600'
|
||||
]"
|
||||
>
|
||||
<span
|
||||
:class="[
|
||||
'inline-block h-4 w-4 transform rounded-full bg-white shadow transition-transform',
|
||||
editForm.claude_code_only ? 'translate-x-6' : 'translate-x-1'
|
||||
]"
|
||||
/>
|
||||
</button>
|
||||
<span class="text-sm text-gray-500 dark:text-gray-400">
|
||||
{{ editForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }}
|
||||
</span>
|
||||
</div>
|
||||
<!-- 降级分组选择(仅当启用 claude_code_only 时显示) -->
|
||||
<div v-if="editForm.claude_code_only" class="mt-3">
|
||||
<label class="input-label">{{ t('admin.groups.claudeCode.fallbackGroup') }}</label>
|
||||
<Select
|
||||
v-model="editForm.fallback_group_id"
|
||||
:options="fallbackGroupOptionsForEdit"
|
||||
:placeholder="t('admin.groups.claudeCode.noFallback')"
|
||||
/>
|
||||
<p class="input-hint">{{ t('admin.groups.claudeCode.fallbackHint') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</form>
|
||||
|
||||
<template #footer>
|
||||
@@ -774,6 +886,35 @@ const subscriptionTypeOptions = computed(() => [
|
||||
{ value: 'subscription', label: t('admin.groups.subscription.subscription') }
|
||||
])
|
||||
|
||||
// 降级分组选项(创建时)- 仅包含 anthropic 平台且未启用 claude_code_only 的分组
|
||||
const fallbackGroupOptions = computed(() => {
|
||||
const options: { value: number | null; label: string }[] = [
|
||||
{ value: null, label: t('admin.groups.claudeCode.noFallback') }
|
||||
]
|
||||
const eligibleGroups = groups.value.filter(
|
||||
(g) => g.platform === 'anthropic' && !g.claude_code_only && g.status === 'active'
|
||||
)
|
||||
eligibleGroups.forEach((g) => {
|
||||
options.push({ value: g.id, label: g.name })
|
||||
})
|
||||
return options
|
||||
})
|
||||
|
||||
// 降级分组选项(编辑时)- 排除自身
|
||||
const fallbackGroupOptionsForEdit = computed(() => {
|
||||
const options: { value: number | null; label: string }[] = [
|
||||
{ value: null, label: t('admin.groups.claudeCode.noFallback') }
|
||||
]
|
||||
const currentId = editingGroup.value?.id
|
||||
const eligibleGroups = groups.value.filter(
|
||||
(g) => g.platform === 'anthropic' && !g.claude_code_only && g.status === 'active' && g.id !== currentId
|
||||
)
|
||||
eligibleGroups.forEach((g) => {
|
||||
options.push({ value: g.id, label: g.name })
|
||||
})
|
||||
return options
|
||||
})
|
||||
|
||||
const groups = ref<Group[]>([])
|
||||
const loading = ref(false)
|
||||
const searchQuery = ref('')
|
||||
@@ -821,7 +962,10 @@ const createForm = reactive({
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
image_price_1k: null as number | null,
|
||||
image_price_2k: null as number | null,
|
||||
image_price_4k: null as number | null
|
||||
image_price_4k: null as number | null,
|
||||
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||
claude_code_only: false,
|
||||
fallback_group_id: null as number | null
|
||||
})
|
||||
|
||||
const editForm = reactive({
|
||||
@@ -838,7 +982,10 @@ const editForm = reactive({
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
image_price_1k: null as number | null,
|
||||
image_price_2k: null as number | null,
|
||||
image_price_4k: null as number | null
|
||||
image_price_4k: null as number | null,
|
||||
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||
claude_code_only: false,
|
||||
fallback_group_id: null as number | null
|
||||
})
|
||||
|
||||
// 根据分组类型返回不同的删除确认消息
|
||||
@@ -908,6 +1055,8 @@ const closeCreateModal = () => {
|
||||
createForm.image_price_1k = null
|
||||
createForm.image_price_2k = null
|
||||
createForm.image_price_4k = null
|
||||
createForm.claude_code_only = false
|
||||
createForm.fallback_group_id = null
|
||||
}
|
||||
|
||||
const handleCreateGroup = async () => {
|
||||
@@ -949,6 +1098,8 @@ const handleEdit = (group: Group) => {
|
||||
editForm.image_price_1k = group.image_price_1k
|
||||
editForm.image_price_2k = group.image_price_2k
|
||||
editForm.image_price_4k = group.image_price_4k
|
||||
editForm.claude_code_only = group.claude_code_only || false
|
||||
editForm.fallback_group_id = group.fallback_group_id
|
||||
showEditModal.value = true
|
||||
}
|
||||
|
||||
@@ -966,7 +1117,12 @@ const handleUpdateGroup = async () => {
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
await adminAPI.groups.update(editingGroup.value.id, editForm)
|
||||
// 转换 fallback_group_id: null -> 0 (后端使用 0 表示清除)
|
||||
const payload = {
|
||||
...editForm,
|
||||
fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id
|
||||
}
|
||||
await adminAPI.groups.update(editingGroup.value.id, payload)
|
||||
appStore.showSuccess(t('admin.groups.groupUpdated'))
|
||||
closeEditModal()
|
||||
loadGroups()
|
||||
|
||||
Reference in New Issue
Block a user