refactor(discord): unify DM command auth handling

This commit is contained in:
Peter Steinberger
2026-03-01 23:59:55 +00:00
parent 12c1257023
commit 75596e9370
4 changed files with 120 additions and 64 deletions

View File

@@ -17,6 +17,32 @@ export type DiscordDmCommandAccess = {
allowMatch: ReturnType<typeof resolveDiscordAllowListMatch> | { allowed: false }; allowMatch: ReturnType<typeof resolveDiscordAllowListMatch> | { allowed: false };
}; };
function resolveSenderAllowMatch(params: {
allowEntries: string[];
sender: { id: string; name?: string; tag?: string };
allowNameMatching: boolean;
}) {
const allowList = normalizeDiscordAllowList(params.allowEntries, DISCORD_ALLOW_LIST_PREFIXES);
return allowList
? resolveDiscordAllowListMatch({
allowList,
candidate: params.sender,
allowNameMatching: params.allowNameMatching,
})
: ({ allowed: false } as const);
}
function resolveDmPolicyCommandAuthorization(params: {
dmPolicy: DiscordDmPolicy;
decision: DmGroupAccessDecision;
commandAuthorized: boolean;
}) {
if (params.dmPolicy === "open" && params.decision === "allow") {
return true;
}
return params.commandAuthorized;
}
export async function resolveDiscordDmCommandAccess(params: { export async function resolveDiscordDmCommandAccess(params: {
accountId: string; accountId: string;
dmPolicy: DiscordDmPolicy; dmPolicy: DiscordDmPolicy;
@@ -40,30 +66,19 @@ export async function resolveDiscordDmCommandAccess(params: {
allowFrom: params.configuredAllowFrom, allowFrom: params.configuredAllowFrom,
groupAllowFrom: [], groupAllowFrom: [],
storeAllowFrom, storeAllowFrom,
isSenderAllowed: (allowEntries) => { isSenderAllowed: (allowEntries) =>
const allowList = normalizeDiscordAllowList(allowEntries, DISCORD_ALLOW_LIST_PREFIXES); resolveSenderAllowMatch({
const allowMatch = allowList allowEntries,
? resolveDiscordAllowListMatch({ sender: params.sender,
allowList, allowNameMatching: params.allowNameMatching,
candidate: params.sender, }).allowed,
allowNameMatching: params.allowNameMatching,
})
: { allowed: false };
return allowMatch.allowed;
},
}); });
const commandAllowList = normalizeDiscordAllowList( const allowMatch = resolveSenderAllowMatch({
access.effectiveAllowFrom, allowEntries: access.effectiveAllowFrom,
DISCORD_ALLOW_LIST_PREFIXES, sender: params.sender,
); allowNameMatching: params.allowNameMatching,
const allowMatch = commandAllowList });
? resolveDiscordAllowListMatch({
allowList: commandAllowList,
candidate: params.sender,
allowNameMatching: params.allowNameMatching,
})
: { allowed: false };
const commandAuthorized = resolveCommandAuthorizedFromAuthorizers({ const commandAuthorized = resolveCommandAuthorizedFromAuthorizers({
useAccessGroups: params.useAccessGroups, useAccessGroups: params.useAccessGroups,
@@ -75,13 +90,15 @@ export async function resolveDiscordDmCommandAccess(params: {
], ],
modeWhenAccessGroupsOff: "configured", modeWhenAccessGroupsOff: "configured",
}); });
const effectiveCommandAuthorized =
access.decision === "allow" && params.dmPolicy === "open" ? true : commandAuthorized;
return { return {
decision: access.decision, decision: access.decision,
reason: access.reason, reason: access.reason,
commandAuthorized: effectiveCommandAuthorized, commandAuthorized: resolveDmPolicyCommandAuthorization({
dmPolicy: params.dmPolicy,
decision: access.decision,
commandAuthorized,
}),
allowMatch, allowMatch,
}; };
} }

View File

@@ -0,0 +1,39 @@
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
import type { DiscordDmCommandAccess } from "./dm-command-auth.js";
export async function handleDiscordDmCommandDecision(params: {
dmAccess: DiscordDmCommandAccess;
accountId: string;
sender: {
id: string;
tag?: string;
name?: string;
};
onPairingCreated: (code: string) => Promise<void>;
onUnauthorized: () => Promise<void>;
upsertPairingRequest?: typeof upsertChannelPairingRequest;
}): Promise<boolean> {
if (params.dmAccess.decision === "allow") {
return true;
}
if (params.dmAccess.decision === "pairing") {
const upsertPairingRequest = params.upsertPairingRequest ?? upsertChannelPairingRequest;
const { code, created } = await upsertPairingRequest({
channel: "discord",
id: params.sender.id,
accountId: params.accountId,
meta: {
tag: params.sender.tag,
name: params.sender.name,
},
});
if (created) {
await params.onPairingCreated(code);
}
return false;
}
await params.onUnauthorized();
return false;
}

View File

@@ -25,7 +25,6 @@ import { enqueueSystemEvent } from "../../infra/system-events.js";
import { logDebug } from "../../logger.js"; import { logDebug } from "../../logger.js";
import { getChildLogger } from "../../logging.js"; import { getChildLogger } from "../../logging.js";
import { buildPairingReply } from "../../pairing/pairing-messages.js"; import { buildPairingReply } from "../../pairing/pairing-messages.js";
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
import { resolveAgentRoute } from "../../routing/resolve-route.js"; import { resolveAgentRoute } from "../../routing/resolve-route.js";
import { DEFAULT_ACCOUNT_ID, resolveAgentIdFromSessionKey } from "../../routing/session-key.js"; import { DEFAULT_ACCOUNT_ID, resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
import { fetchPluralKitMessageInfo } from "../pluralkit.js"; import { fetchPluralKitMessageInfo } from "../pluralkit.js";
@@ -42,6 +41,7 @@ import {
resolveGroupDmAllow, resolveGroupDmAllow,
} from "./allow-list.js"; } from "./allow-list.js";
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js"; import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
import { import {
formatDiscordUserTag, formatDiscordUserTag,
resolveDiscordSystemLocation, resolveDiscordSystemLocation,
@@ -175,6 +175,7 @@ export async function preflightDiscordMessage(
const dmPolicy = params.discordConfig?.dmPolicy ?? params.discordConfig?.dm?.policy ?? "pairing"; const dmPolicy = params.discordConfig?.dmPolicy ?? params.discordConfig?.dm?.policy ?? "pairing";
const useAccessGroups = params.cfg.commands?.useAccessGroups !== false; const useAccessGroups = params.cfg.commands?.useAccessGroups !== false;
const resolvedAccountId = params.accountId ?? DEFAULT_ACCOUNT_ID; const resolvedAccountId = params.accountId ?? DEFAULT_ACCOUNT_ID;
const allowNameMatching = isDangerousNameMatchingEnabled(params.discordConfig);
let commandAuthorized = true; let commandAuthorized = true;
if (isDirectMessage) { if (isDirectMessage) {
if (dmPolicy === "disabled") { if (dmPolicy === "disabled") {
@@ -190,7 +191,7 @@ export async function preflightDiscordMessage(
name: sender.name, name: sender.name,
tag: sender.tag, tag: sender.tag,
}, },
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig), allowNameMatching,
useAccessGroups, useAccessGroups,
}); });
commandAuthorized = dmAccess.commandAuthorized; commandAuthorized = dmAccess.commandAuthorized;
@@ -198,17 +199,15 @@ export async function preflightDiscordMessage(
const allowMatchMeta = formatAllowlistMatchMeta( const allowMatchMeta = formatAllowlistMatchMeta(
dmAccess.allowMatch.allowed ? dmAccess.allowMatch : undefined, dmAccess.allowMatch.allowed ? dmAccess.allowMatch : undefined,
); );
if (dmAccess.decision === "pairing") { await handleDiscordDmCommandDecision({
const { code, created } = await upsertChannelPairingRequest({ dmAccess,
channel: "discord", accountId: resolvedAccountId,
sender: {
id: author.id, id: author.id,
accountId: resolvedAccountId, tag: formatDiscordUserTag(author),
meta: { name: author.username ?? undefined,
tag: formatDiscordUserTag(author), },
name: author.username ?? undefined, onPairingCreated: async (code) => {
},
});
if (created) {
logVerbose( logVerbose(
`discord pairing request sender=${author.id} tag=${formatDiscordUserTag(author)} (${allowMatchMeta})`, `discord pairing request sender=${author.id} tag=${formatDiscordUserTag(author)} (${allowMatchMeta})`,
); );
@@ -229,12 +228,13 @@ export async function preflightDiscordMessage(
} catch (err) { } catch (err) {
logVerbose(`discord pairing reply failed for ${author.id}: ${String(err)}`); logVerbose(`discord pairing reply failed for ${author.id}: ${String(err)}`);
} }
} },
} else { onUnauthorized: async () => {
logVerbose( logVerbose(
`Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`, `Blocked unauthorized discord sender ${sender.id} (dmPolicy=${dmPolicy}, ${allowMatchMeta})`,
); );
} },
});
return null; return null;
} }
} }
@@ -570,7 +570,7 @@ export async function preflightDiscordMessage(
guildInfo, guildInfo,
memberRoleIds, memberRoleIds,
sender, sender,
allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig), allowNameMatching,
}); });
if (!isDirectMessage) { if (!isDirectMessage) {
@@ -587,7 +587,7 @@ export async function preflightDiscordMessage(
name: sender.name, name: sender.name,
tag: sender.tag, tag: sender.tag,
}, },
{ allowNameMatching: isDangerousNameMatchingEnabled(params.discordConfig) }, { allowNameMatching },
) )
: false; : false;
const commandGate = resolveControlCommandGate({ const commandGate = resolveControlCommandGate({

View File

@@ -46,7 +46,6 @@ import { logVerbose } from "../../globals.js";
import { createSubsystemLogger } from "../../logging/subsystem.js"; import { createSubsystemLogger } from "../../logging/subsystem.js";
import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js"; import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js";
import { buildPairingReply } from "../../pairing/pairing-messages.js"; import { buildPairingReply } from "../../pairing/pairing-messages.js";
import { upsertChannelPairingRequest } from "../../pairing/pairing-store.js";
import { resolveAgentRoute } from "../../routing/resolve-route.js"; import { resolveAgentRoute } from "../../routing/resolve-route.js";
import { resolveAgentIdFromSessionKey } from "../../routing/session-key.js"; import { resolveAgentIdFromSessionKey } from "../../routing/session-key.js";
import { buildUntrustedChannelMetadata } from "../../security/channel-metadata.js"; import { buildUntrustedChannelMetadata } from "../../security/channel-metadata.js";
@@ -65,6 +64,7 @@ import {
resolveDiscordOwnerAllowFrom, resolveDiscordOwnerAllowFrom,
} from "./allow-list.js"; } from "./allow-list.js";
import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js"; import { resolveDiscordDmCommandAccess } from "./dm-command-auth.js";
import { handleDiscordDmCommandDecision } from "./dm-command-decision.js";
import { resolveDiscordChannelInfo } from "./message-utils.js"; import { resolveDiscordChannelInfo } from "./message-utils.js";
import { import {
readDiscordModelPickerRecentModels, readDiscordModelPickerRecentModels,
@@ -1269,6 +1269,7 @@ async function dispatchDiscordCommandInteraction(params: {
const memberRoleIds = Array.isArray(interaction.rawData.member?.roles) const memberRoleIds = Array.isArray(interaction.rawData.member?.roles)
? interaction.rawData.member.roles.map((roleId: string) => String(roleId)) ? interaction.rawData.member.roles.map((roleId: string) => String(roleId))
: []; : [];
const allowNameMatching = isDangerousNameMatchingEnabled(discordConfig);
const ownerAllowList = normalizeDiscordAllowList( const ownerAllowList = normalizeDiscordAllowList(
discordConfig?.allowFrom ?? discordConfig?.dm?.allowFrom ?? [], discordConfig?.allowFrom ?? discordConfig?.dm?.allowFrom ?? [],
["discord:", "user:", "pk:"], ["discord:", "user:", "pk:"],
@@ -1282,7 +1283,7 @@ async function dispatchDiscordCommandInteraction(params: {
name: sender.name, name: sender.name,
tag: sender.tag, tag: sender.tag,
}, },
{ allowNameMatching: isDangerousNameMatchingEnabled(discordConfig) }, { allowNameMatching },
) )
: false; : false;
const guildInfo = resolveDiscordGuildEntry({ const guildInfo = resolveDiscordGuildEntry({
@@ -1366,22 +1367,20 @@ async function dispatchDiscordCommandInteraction(params: {
name: sender.name, name: sender.name,
tag: sender.tag, tag: sender.tag,
}, },
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig), allowNameMatching,
useAccessGroups, useAccessGroups,
}); });
commandAuthorized = dmAccess.commandAuthorized; commandAuthorized = dmAccess.commandAuthorized;
if (dmAccess.decision !== "allow") { if (dmAccess.decision !== "allow") {
if (dmAccess.decision === "pairing") { await handleDiscordDmCommandDecision({
const { code, created } = await upsertChannelPairingRequest({ dmAccess,
channel: "discord", accountId,
sender: {
id: user.id, id: user.id,
accountId, tag: sender.tag,
meta: { name: sender.name,
tag: sender.tag, },
name: sender.name, onPairingCreated: async (code) => {
},
});
if (created) {
await respond( await respond(
buildPairingReply({ buildPairingReply({
channel: "discord", channel: "discord",
@@ -1390,10 +1389,11 @@ async function dispatchDiscordCommandInteraction(params: {
}), }),
{ ephemeral: true }, { ephemeral: true },
); );
} },
} else { onUnauthorized: async () => {
await respond("You are not authorized to use this command.", { ephemeral: true }); await respond("You are not authorized to use this command.", { ephemeral: true });
} },
});
return; return;
} }
} }
@@ -1403,7 +1403,7 @@ async function dispatchDiscordCommandInteraction(params: {
guildInfo, guildInfo,
memberRoleIds, memberRoleIds,
sender, sender,
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig), allowNameMatching,
}); });
const authorizers = useAccessGroups const authorizers = useAccessGroups
? [ ? [
@@ -1509,7 +1509,7 @@ async function dispatchDiscordCommandInteraction(params: {
channelConfig, channelConfig,
guildInfo, guildInfo,
sender: { id: sender.id, name: sender.name, tag: sender.tag }, sender: { id: sender.id, name: sender.name, tag: sender.tag },
allowNameMatching: isDangerousNameMatchingEnabled(discordConfig), allowNameMatching,
}); });
const ctxPayload = finalizeInboundContext({ const ctxPayload = finalizeInboundContext({
Body: prompt, Body: prompt,