refactor: dedupe channel and gateway surfaces

This commit is contained in:
Peter Steinberger
2026-03-02 19:48:12 +00:00
parent 9617ac9dd5
commit 9d30159fcd
44 changed files with 1072 additions and 1479 deletions

View File

@@ -12,16 +12,44 @@ function buildDmAccess(overrides: Partial<DiscordDmCommandAccess>): DiscordDmCom
};
}
const TEST_ACCOUNT_ID = "default";
const TEST_SENDER = { id: "123", tag: "alice#0001", name: "alice" };
function createDmDecisionHarness(params?: { pairingCreated?: boolean }) {
const onPairingCreated = vi.fn(async () => {});
const onUnauthorized = vi.fn(async () => {});
const upsertPairingRequest = vi.fn(async () => ({
code: "PAIR-1",
created: params?.pairingCreated ?? true,
}));
return { onPairingCreated, onUnauthorized, upsertPairingRequest };
}
async function runPairingDecision(params?: { pairingCreated?: boolean }) {
const harness = createDmDecisionHarness({ pairingCreated: params?.pairingCreated });
const allowed = await handleDiscordDmCommandDecision({
dmAccess: buildDmAccess({
decision: "pairing",
commandAuthorized: false,
allowMatch: { allowed: false },
}),
accountId: TEST_ACCOUNT_ID,
sender: TEST_SENDER,
onPairingCreated: harness.onPairingCreated,
onUnauthorized: harness.onUnauthorized,
upsertPairingRequest: harness.upsertPairingRequest,
});
return { allowed, ...harness };
}
describe("handleDiscordDmCommandDecision", () => {
it("returns true for allowed DM access", async () => {
const onPairingCreated = vi.fn(async () => {});
const onUnauthorized = vi.fn(async () => {});
const upsertPairingRequest = vi.fn(async () => ({ code: "PAIR-1", created: true }));
const { onPairingCreated, onUnauthorized, upsertPairingRequest } = createDmDecisionHarness();
const allowed = await handleDiscordDmCommandDecision({
dmAccess: buildDmAccess({ decision: "allow" }),
accountId: "default",
sender: { id: "123", tag: "alice#0001", name: "alice" },
accountId: TEST_ACCOUNT_ID,
sender: TEST_SENDER,
onPairingCreated,
onUnauthorized,
upsertPairingRequest,
@@ -34,31 +62,17 @@ describe("handleDiscordDmCommandDecision", () => {
});
it("creates pairing reply for new pairing requests", async () => {
const onPairingCreated = vi.fn(async () => {});
const onUnauthorized = vi.fn(async () => {});
const upsertPairingRequest = vi.fn(async () => ({ code: "PAIR-1", created: true }));
const allowed = await handleDiscordDmCommandDecision({
dmAccess: buildDmAccess({
decision: "pairing",
commandAuthorized: false,
allowMatch: { allowed: false },
}),
accountId: "default",
sender: { id: "123", tag: "alice#0001", name: "alice" },
onPairingCreated,
onUnauthorized,
upsertPairingRequest,
});
const { allowed, onPairingCreated, onUnauthorized, upsertPairingRequest } =
await runPairingDecision();
expect(allowed).toBe(false);
expect(upsertPairingRequest).toHaveBeenCalledWith({
channel: "discord",
id: "123",
accountId: "default",
accountId: TEST_ACCOUNT_ID,
meta: {
tag: "alice#0001",
name: "alice",
tag: TEST_SENDER.tag,
name: TEST_SENDER.name,
},
});
expect(onPairingCreated).toHaveBeenCalledWith("PAIR-1");
@@ -66,21 +80,8 @@ describe("handleDiscordDmCommandDecision", () => {
});
it("skips pairing reply when pairing request already exists", async () => {
const onPairingCreated = vi.fn(async () => {});
const onUnauthorized = vi.fn(async () => {});
const upsertPairingRequest = vi.fn(async () => ({ code: "PAIR-1", created: false }));
const allowed = await handleDiscordDmCommandDecision({
dmAccess: buildDmAccess({
decision: "pairing",
commandAuthorized: false,
allowMatch: { allowed: false },
}),
accountId: "default",
sender: { id: "123", tag: "alice#0001", name: "alice" },
onPairingCreated,
onUnauthorized,
upsertPairingRequest,
const { allowed, onPairingCreated, onUnauthorized } = await runPairingDecision({
pairingCreated: false,
});
expect(allowed).toBe(false);
@@ -89,9 +90,7 @@ describe("handleDiscordDmCommandDecision", () => {
});
it("runs unauthorized handler for blocked DM access", async () => {
const onPairingCreated = vi.fn(async () => {});
const onUnauthorized = vi.fn(async () => {});
const upsertPairingRequest = vi.fn(async () => ({ code: "PAIR-1", created: true }));
const { onPairingCreated, onUnauthorized, upsertPairingRequest } = createDmDecisionHarness();
const allowed = await handleDiscordDmCommandDecision({
dmAccess: buildDmAccess({
@@ -99,8 +98,8 @@ describe("handleDiscordDmCommandDecision", () => {
commandAuthorized: false,
allowMatch: { allowed: false },
}),
accountId: "default",
sender: { id: "123", tag: "alice#0001", name: "alice" },
accountId: TEST_ACCOUNT_ID,
sender: TEST_SENDER,
onPairingCreated,
onUnauthorized,
upsertPairingRequest,

View File

@@ -374,7 +374,7 @@ async function handleDiscordReactionEvent(params: {
channelType === ChannelType.PublicThread ||
channelType === ChannelType.PrivateThread ||
channelType === ChannelType.AnnouncementThread;
const ingressAccess = await authorizeDiscordReactionIngress({
const reactionIngressBase: Omit<DiscordReactionIngressAuthorizationParams, "channelConfig"> = {
accountId: params.accountId,
user,
isDirectMessage,
@@ -391,7 +391,8 @@ async function handleDiscordReactionEvent(params: {
groupPolicy: params.groupPolicy,
allowNameMatching: params.allowNameMatching,
guildInfo,
});
};
const ingressAccess = await authorizeDiscordReactionIngress(reactionIngressBase);
if (!ingressAccess.allowed) {
logVerbose(`discord reaction blocked sender=${user.id} (reason=${ingressAccess.reason})`);
return;
@@ -486,22 +487,7 @@ async function handleDiscordReactionEvent(params: {
channelConfig: ReturnType<typeof resolveDiscordChannelConfigWithFallback>,
) =>
await authorizeDiscordReactionIngress({
accountId: params.accountId,
user,
isDirectMessage,
isGroupDm,
isGuildMessage,
channelId: data.channel_id,
channelName,
channelSlug,
dmEnabled: params.dmEnabled,
groupDmEnabled: params.groupDmEnabled,
groupDmChannels: params.groupDmChannels,
dmPolicy: params.dmPolicy,
allowFrom: params.allowFrom,
groupPolicy: params.groupPolicy,
allowNameMatching: params.allowNameMatching,
guildInfo,
...reactionIngressBase,
channelConfig,
});
const authorizeThreadChannelAccess = async (channelInfo: { parentId?: string } | null) => {

View File

@@ -3,7 +3,10 @@ import { inboundCtxCapture as capture } from "../../../test/helpers/inbound-cont
import { expectInboundContextContract } from "../../../test/helpers/inbound-contract.js";
import type { DiscordMessagePreflightContext } from "./message-handler.preflight.js";
import { processDiscordMessage } from "./message-handler.process.js";
import { createBaseDiscordMessageContext } from "./message-handler.test-harness.js";
import {
createBaseDiscordMessageContext,
createDiscordDirectMessageContextOverrides,
} from "./message-handler.test-harness.js";
describe("discord processDiscordMessage inbound contract", () => {
it("passes a finalized MsgContext to dispatchInboundMessage", async () => {
@@ -11,26 +14,7 @@ describe("discord processDiscordMessage inbound contract", () => {
const messageCtx = await createBaseDiscordMessageContext({
cfg: { messages: {} },
ackReactionScope: "direct",
data: { guild: null },
channelInfo: null,
channelName: undefined,
isGuildMessage: false,
isDirectMessage: true,
isGroupDm: false,
shouldRequireMention: false,
canDetectMention: false,
effectiveWasMentioned: false,
displayChannelSlug: "",
guildInfo: null,
guildSlug: "",
baseSessionKey: "agent:main:discord:direct:u1",
route: {
agentId: "main",
channel: "discord",
accountId: "default",
sessionKey: "agent:main:discord:direct:u1",
mainSessionKey: "agent:main:main",
},
...createDiscordDirectMessageContextOverrides(),
});
await processDiscordMessage(messageCtx);

View File

@@ -1,6 +1,9 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { DEFAULT_EMOJIS } from "../../channels/status-reactions.js";
import { createBaseDiscordMessageContext } from "./message-handler.test-harness.js";
import {
createBaseDiscordMessageContext,
createDiscordDirectMessageContextOverrides,
} from "./message-handler.test-harness.js";
import {
__testing as threadBindingTesting,
createThreadBindingManager,
@@ -295,18 +298,7 @@ describe("processDiscordMessage ack reactions", () => {
describe("processDiscordMessage session routing", () => {
it("stores DM lastRoute with user target for direct-session continuity", async () => {
const ctx = await createBaseContext({
data: { guild: null },
channelInfo: null,
channelName: undefined,
isGuildMessage: false,
isDirectMessage: true,
isGroupDm: false,
shouldRequireMention: false,
canDetectMention: false,
effectiveWasMentioned: false,
displayChannelSlug: "",
guildInfo: null,
guildSlug: "",
...createDiscordDirectMessageContextOverrides(),
message: {
id: "m1",
channelId: "dm1",
@@ -314,14 +306,6 @@ describe("processDiscordMessage session routing", () => {
attachments: [],
},
messageChannelId: "dm1",
baseSessionKey: "agent:main:discord:direct:u1",
route: {
agentId: "main",
channel: "discord",
accountId: "default",
sessionKey: "agent:main:discord:direct:u1",
mainSessionKey: "agent:main:main",
},
});
// oxlint-disable-next-line typescript/no-explicit-any

View File

@@ -72,3 +72,28 @@ export async function createBaseDiscordMessageContext(
...overrides,
} as unknown as DiscordMessagePreflightContext;
}
export function createDiscordDirectMessageContextOverrides(): Record<string, unknown> {
return {
data: { guild: null },
channelInfo: null,
channelName: undefined,
isGuildMessage: false,
isDirectMessage: true,
isGroupDm: false,
shouldRequireMention: false,
canDetectMention: false,
effectiveWasMentioned: false,
displayChannelSlug: "",
guildInfo: null,
guildSlug: "",
baseSessionKey: "agent:main:discord:direct:u1",
route: {
agentId: "main",
channel: "discord",
accountId: "default",
sessionKey: "agent:main:discord:direct:u1",
mainSessionKey: "agent:main:main",
},
};
}

View File

@@ -30,6 +30,68 @@ function asMessage(payload: Record<string, unknown>): Message {
return payload as unknown as Message;
}
function expectSinglePngDownload(params: {
result: unknown;
expectedUrl: string;
filePathHint: string;
expectedPath: string;
placeholder: "<media:image>" | "<media:sticker>";
}) {
expect(fetchRemoteMedia).toHaveBeenCalledTimes(1);
expect(fetchRemoteMedia).toHaveBeenCalledWith({
url: params.expectedUrl,
filePathHint: params.filePathHint,
maxBytes: 512,
fetchImpl: undefined,
ssrfPolicy: expect.objectContaining({ allowRfc2544BenchmarkRange: true }),
});
expect(saveMediaBuffer).toHaveBeenCalledTimes(1);
expect(saveMediaBuffer).toHaveBeenCalledWith(expect.any(Buffer), "image/png", "inbound", 512);
expect(params.result).toEqual([
{
path: params.expectedPath,
contentType: "image/png",
placeholder: params.placeholder,
},
]);
}
function expectAttachmentImageFallback(params: { result: unknown; attachment: { url: string } }) {
expect(saveMediaBuffer).not.toHaveBeenCalled();
expect(params.result).toEqual([
{
path: params.attachment.url,
contentType: "image/png",
placeholder: "<media:image>",
},
]);
}
function asForwardedSnapshotMessage(params: {
content: string;
embeds: Array<{ title?: string; description?: string }>;
}) {
return asMessage({
content: "",
rawData: {
message_snapshots: [
{
message: {
content: params.content,
embeds: params.embeds,
attachments: [],
author: {
id: "u2",
username: "Bob",
discriminator: "0",
},
},
},
],
},
});
}
describe("resolveDiscordMessageChannelId", () => {
it.each([
{
@@ -157,14 +219,7 @@ describe("resolveForwardedMediaList", () => {
512,
);
expect(saveMediaBuffer).not.toHaveBeenCalled();
expect(result).toEqual([
{
path: attachment.url,
contentType: "image/png",
placeholder: "<media:image>",
},
]);
expectAttachmentImageFallback({ result, attachment });
});
it("downloads forwarded stickers", async () => {
@@ -191,23 +246,13 @@ describe("resolveForwardedMediaList", () => {
512,
);
expect(fetchRemoteMedia).toHaveBeenCalledTimes(1);
expect(fetchRemoteMedia).toHaveBeenCalledWith({
url: "https://media.discordapp.net/stickers/sticker-1.png",
expectSinglePngDownload({
result,
expectedUrl: "https://media.discordapp.net/stickers/sticker-1.png",
filePathHint: "wave.png",
maxBytes: 512,
fetchImpl: undefined,
ssrfPolicy: expect.objectContaining({ allowRfc2544BenchmarkRange: true }),
expectedPath: "/tmp/sticker.png",
placeholder: "<media:sticker>",
});
expect(saveMediaBuffer).toHaveBeenCalledTimes(1);
expect(saveMediaBuffer).toHaveBeenCalledWith(expect.any(Buffer), "image/png", "inbound", 512);
expect(result).toEqual([
{
path: "/tmp/sticker.png",
contentType: "image/png",
placeholder: "<media:sticker>",
},
]);
});
it("returns empty when no snapshots are present", async () => {
@@ -260,23 +305,13 @@ describe("resolveMediaList", () => {
512,
);
expect(fetchRemoteMedia).toHaveBeenCalledTimes(1);
expect(fetchRemoteMedia).toHaveBeenCalledWith({
url: "https://media.discordapp.net/stickers/sticker-2.png",
expectSinglePngDownload({
result,
expectedUrl: "https://media.discordapp.net/stickers/sticker-2.png",
filePathHint: "hello.png",
maxBytes: 512,
fetchImpl: undefined,
ssrfPolicy: expect.objectContaining({ allowRfc2544BenchmarkRange: true }),
expectedPath: "/tmp/sticker-2.png",
placeholder: "<media:sticker>",
});
expect(saveMediaBuffer).toHaveBeenCalledTimes(1);
expect(saveMediaBuffer).toHaveBeenCalledWith(expect.any(Buffer), "image/png", "inbound", 512);
expect(result).toEqual([
{
path: "/tmp/sticker-2.png",
contentType: "image/png",
placeholder: "<media:sticker>",
},
]);
});
it("forwards fetchImpl to sticker downloads", async () => {
@@ -324,14 +359,7 @@ describe("resolveMediaList", () => {
512,
);
expect(saveMediaBuffer).not.toHaveBeenCalled();
expect(result).toEqual([
{
path: attachment.url,
contentType: "image/png",
placeholder: "<media:image>",
},
]);
expectAttachmentImageFallback({ result, attachment });
});
it("falls back to URL when saveMediaBuffer fails", async () => {
@@ -471,24 +499,9 @@ describe("Discord media SSRF policy", () => {
describe("resolveDiscordMessageText", () => {
it("includes forwarded message snapshots in body text", () => {
const text = resolveDiscordMessageText(
asMessage({
content: "",
rawData: {
message_snapshots: [
{
message: {
content: "forwarded hello",
embeds: [],
attachments: [],
author: {
id: "u2",
username: "Bob",
discriminator: "0",
},
},
},
],
},
asForwardedSnapshotMessage({
content: "forwarded hello",
embeds: [],
}),
{ includeForwarded: true },
);
@@ -560,24 +573,9 @@ describe("resolveDiscordMessageText", () => {
it("joins forwarded snapshot embed title and description when content is empty", () => {
const text = resolveDiscordMessageText(
asMessage({
asForwardedSnapshotMessage({
content: "",
rawData: {
message_snapshots: [
{
message: {
content: "",
embeds: [{ title: "Forwarded title", description: "Forwarded details" }],
attachments: [],
author: {
id: "u2",
username: "Bob",
discriminator: "0",
},
},
},
],
},
embeds: [{ title: "Forwarded title", description: "Forwarded details" }],
}),
{ includeForwarded: true },
);

View File

@@ -122,6 +122,27 @@ describe("runDiscordGatewayLifecycle", () => {
expect(params.releaseEarlyGatewayErrorGuard).toHaveBeenCalledTimes(1);
}
function createGatewayHarness(params?: {
state?: {
sessionId?: string | null;
resumeGatewayUrl?: string | null;
sequence?: number | null;
};
sequence?: number | null;
}) {
const emitter = new EventEmitter();
const gateway = {
isConnected: false,
options: {},
disconnect: vi.fn(),
connect: vi.fn(),
...(params?.state ? { state: params.state } : {}),
...(params?.sequence !== undefined ? { sequence: params.sequence } : {}),
emitter,
};
return { emitter, gateway };
}
it("cleans up thread bindings when exec approvals startup fails", async () => {
const { runDiscordGatewayLifecycle } = await import("./provider.lifecycle.js");
const { lifecycleParams, start, stop, threadStop, releaseEarlyGatewayErrorGuard } =
@@ -229,20 +250,14 @@ describe("runDiscordGatewayLifecycle", () => {
vi.useFakeTimers();
try {
const { runDiscordGatewayLifecycle } = await import("./provider.lifecycle.js");
const emitter = new EventEmitter();
const gateway = {
isConnected: false,
options: {},
disconnect: vi.fn(),
connect: vi.fn(),
const { emitter, gateway } = createGatewayHarness({
state: {
sessionId: "session-1",
resumeGatewayUrl: "wss://gateway.discord.gg",
sequence: 123,
},
sequence: 123,
emitter,
};
});
getDiscordGatewayEmitterMock.mockReturnValueOnce(emitter);
waitForDiscordGatewayStopMock.mockImplementationOnce(async () => {
emitter.emit("debug", "WebSocket connection opened");
@@ -260,9 +275,10 @@ describe("runDiscordGatewayLifecycle", () => {
expect(gateway.connect).toHaveBeenNthCalledWith(1, true);
expect(gateway.connect).toHaveBeenNthCalledWith(2, true);
expect(gateway.connect).toHaveBeenNthCalledWith(3, false);
expect(gateway.state.sessionId).toBeNull();
expect(gateway.state.resumeGatewayUrl).toBeNull();
expect(gateway.state.sequence).toBeNull();
expect(gateway.state).toBeDefined();
expect(gateway.state?.sessionId).toBeNull();
expect(gateway.state?.resumeGatewayUrl).toBeNull();
expect(gateway.state?.sequence).toBeNull();
expect(gateway.sequence).toBeNull();
} finally {
vi.useRealTimers();
@@ -273,20 +289,14 @@ describe("runDiscordGatewayLifecycle", () => {
vi.useFakeTimers();
try {
const { runDiscordGatewayLifecycle } = await import("./provider.lifecycle.js");
const emitter = new EventEmitter();
const gateway = {
isConnected: false,
options: {},
disconnect: vi.fn(),
connect: vi.fn(),
const { emitter, gateway } = createGatewayHarness({
state: {
sessionId: "session-2",
resumeGatewayUrl: "wss://gateway.discord.gg",
sequence: 456,
},
sequence: 456,
emitter,
};
});
getDiscordGatewayEmitterMock.mockReturnValueOnce(emitter);
waitForDiscordGatewayStopMock.mockImplementationOnce(async () => {
emitter.emit("debug", "WebSocket connection opened");
@@ -324,14 +334,7 @@ describe("runDiscordGatewayLifecycle", () => {
vi.useFakeTimers();
try {
const { runDiscordGatewayLifecycle } = await import("./provider.lifecycle.js");
const emitter = new EventEmitter();
const gateway = {
isConnected: false,
options: {},
disconnect: vi.fn(),
connect: vi.fn(),
emitter,
};
const { emitter, gateway } = createGatewayHarness();
getDiscordGatewayEmitterMock.mockReturnValueOnce(emitter);
waitForDiscordGatewayStopMock.mockImplementationOnce(
(waitParams: WaitForDiscordGatewayStopParams) =>
@@ -356,14 +359,7 @@ describe("runDiscordGatewayLifecycle", () => {
vi.useFakeTimers();
try {
const { runDiscordGatewayLifecycle } = await import("./provider.lifecycle.js");
const emitter = new EventEmitter();
const gateway = {
isConnected: false,
options: {},
disconnect: vi.fn(),
connect: vi.fn(),
emitter,
};
const { emitter, gateway } = createGatewayHarness();
getDiscordGatewayEmitterMock.mockReturnValueOnce(emitter);
let resolveWait: (() => void) | undefined;
waitForDiscordGatewayStopMock.mockImplementationOnce(

View File

@@ -14,6 +14,11 @@ import { resolveTextChunkLimit } from "../../auto-reply/chunk.js";
import { listNativeCommandSpecsForConfig } from "../../auto-reply/commands-registry.js";
import type { HistoryEntry } from "../../auto-reply/reply/history.js";
import { listSkillCommandsForAgents } from "../../auto-reply/skill-commands.js";
import {
resolveThreadBindingIdleTimeoutMs,
resolveThreadBindingMaxAgeMs,
resolveThreadBindingsEnabled,
} from "../../channels/thread-bindings-policy.js";
import {
isNativeCommandsExplicitlyDisabled,
resolveNativeCommandsEnabled,
@@ -110,59 +115,6 @@ function summarizeGuilds(entries?: Record<string, unknown>) {
return `${sample.join(", ")}${suffix}`;
}
const DEFAULT_THREAD_BINDING_IDLE_HOURS = 24;
const DEFAULT_THREAD_BINDING_MAX_AGE_HOURS = 0;
function normalizeThreadBindingHours(raw: unknown): number | undefined {
if (typeof raw !== "number" || !Number.isFinite(raw)) {
return undefined;
}
if (raw < 0) {
return undefined;
}
return raw;
}
function resolveThreadBindingIdleTimeoutMs(params: {
channelIdleHoursRaw: unknown;
sessionIdleHoursRaw: unknown;
}): number {
const idleHours =
normalizeThreadBindingHours(params.channelIdleHoursRaw) ??
normalizeThreadBindingHours(params.sessionIdleHoursRaw) ??
DEFAULT_THREAD_BINDING_IDLE_HOURS;
return Math.floor(idleHours * 60 * 60 * 1000);
}
function resolveThreadBindingMaxAgeMs(params: {
channelMaxAgeHoursRaw: unknown;
sessionMaxAgeHoursRaw: unknown;
}): number {
const maxAgeHours =
normalizeThreadBindingHours(params.channelMaxAgeHoursRaw) ??
normalizeThreadBindingHours(params.sessionMaxAgeHoursRaw) ??
DEFAULT_THREAD_BINDING_MAX_AGE_HOURS;
return Math.floor(maxAgeHours * 60 * 60 * 1000);
}
function normalizeThreadBindingsEnabled(raw: unknown): boolean | undefined {
if (typeof raw !== "boolean") {
return undefined;
}
return raw;
}
function resolveThreadBindingsEnabled(params: {
channelEnabledRaw: unknown;
sessionEnabledRaw: unknown;
}): boolean {
return (
normalizeThreadBindingsEnabled(params.channelEnabledRaw) ??
normalizeThreadBindingsEnabled(params.sessionEnabledRaw) ??
true
);
}
function formatThreadBindingDurationForConfigLabel(durationMs: number): string {
const label = formatThreadBindingDurationLabel(durationMs);
return label === "disabled" ? "off" : label;
@@ -612,43 +564,26 @@ export async function monitorDiscordProvider(opts: MonitorDiscordOpts = {}) {
client.listeners,
new DiscordMessageListener(messageHandler, logger, trackInboundEvent),
);
const reactionListenerOptions = {
cfg,
accountId: account.accountId,
runtime,
botUserId,
dmEnabled,
groupDmEnabled,
groupDmChannels: groupDmChannels ?? [],
dmPolicy,
allowFrom: allowFrom ?? [],
groupPolicy,
allowNameMatching: isDangerousNameMatchingEnabled(discordCfg),
guildEntries,
logger,
onEvent: trackInboundEvent,
};
registerDiscordListener(client.listeners, new DiscordReactionListener(reactionListenerOptions));
registerDiscordListener(
client.listeners,
new DiscordReactionListener({
cfg,
accountId: account.accountId,
runtime,
botUserId,
dmEnabled,
groupDmEnabled,
groupDmChannels: groupDmChannels ?? [],
dmPolicy,
allowFrom: allowFrom ?? [],
groupPolicy,
allowNameMatching: isDangerousNameMatchingEnabled(discordCfg),
guildEntries,
logger,
onEvent: trackInboundEvent,
}),
);
registerDiscordListener(
client.listeners,
new DiscordReactionRemoveListener({
cfg,
accountId: account.accountId,
runtime,
botUserId,
dmEnabled,
groupDmEnabled,
groupDmChannels: groupDmChannels ?? [],
dmPolicy,
allowFrom: allowFrom ?? [],
groupPolicy,
allowNameMatching: isDangerousNameMatchingEnabled(discordCfg),
guildEntries,
logger,
onEvent: trackInboundEvent,
}),
new DiscordReactionRemoveListener(reactionListenerOptions),
);
if (discordCfg.intents?.presence) {

View File

@@ -4,6 +4,28 @@ import { resolveDiscordChannelAllowlist } from "./resolve-channels.js";
import { jsonResponse, urlToString } from "./test-http-helpers.js";
describe("resolveDiscordChannelAllowlist", () => {
async function resolveWithChannelLookup(params: {
guilds: Array<{ id: string; name: string }>;
channel: { id: string; name: string; guild_id: string; type: number };
entry: string;
}) {
const fetcher = withFetchPreconnect(async (input: RequestInfo | URL) => {
const url = urlToString(input);
if (url.endsWith("/users/@me/guilds")) {
return jsonResponse(params.guilds);
}
if (url.endsWith(`/channels/${params.channel.id}`)) {
return jsonResponse(params.channel);
}
return new Response("not found", { status: 404 });
});
return resolveDiscordChannelAllowlist({
token: "test",
entries: [params.entry],
fetcher,
});
}
it("resolves guild/channel by name", async () => {
const fetcher = withFetchPreconnect(async (input: RequestInfo | URL) => {
const url = urlToString(input);
@@ -54,21 +76,10 @@ describe("resolveDiscordChannelAllowlist", () => {
});
it("resolves guildId/channelId entries via channel lookup", async () => {
const fetcher = withFetchPreconnect(async (input: RequestInfo | URL) => {
const url = urlToString(input);
if (url.endsWith("/users/@me/guilds")) {
return jsonResponse([{ id: "111", name: "Guild One" }]);
}
if (url.endsWith("/channels/222")) {
return jsonResponse({ id: "222", name: "general", guild_id: "111", type: 0 });
}
return new Response("not found", { status: 404 });
});
const res = await resolveDiscordChannelAllowlist({
token: "test",
entries: ["111/222"],
fetcher,
const res = await resolveWithChannelLookup({
guilds: [{ id: "111", name: "Guild One" }],
channel: { id: "222", name: "general", guild_id: "111", type: 0 },
entry: "111/222",
});
expect(res[0]).toMatchObject({
@@ -82,24 +93,13 @@ describe("resolveDiscordChannelAllowlist", () => {
});
it("reports unresolved when channel id belongs to a different guild", async () => {
const fetcher = withFetchPreconnect(async (input: RequestInfo | URL) => {
const url = urlToString(input);
if (url.endsWith("/users/@me/guilds")) {
return jsonResponse([
{ id: "111", name: "Guild One" },
{ id: "333", name: "Guild Two" },
]);
}
if (url.endsWith("/channels/222")) {
return jsonResponse({ id: "222", name: "general", guild_id: "333", type: 0 });
}
return new Response("not found", { status: 404 });
});
const res = await resolveDiscordChannelAllowlist({
token: "test",
entries: ["111/222"],
fetcher,
const res = await resolveWithChannelLookup({
guilds: [
{ id: "111", name: "Guild One" },
{ id: "333", name: "Guild Two" },
],
channel: { id: "222", name: "general", guild_id: "333", type: 0 },
entry: "111/222",
});
expect(res[0]).toMatchObject({

View File

@@ -1,9 +1,7 @@
import type { DirectoryConfigParams } from "../channels/plugins/directory-config.js";
import {
buildMessagingTarget,
ensureTargetId,
parseTargetMention,
parseTargetPrefixes,
parseMentionPrefixOrAtUserTarget,
requireTargetKind,
type MessagingTarget,
type MessagingTargetKind,
@@ -25,33 +23,19 @@ export function parseDiscordTarget(
if (!trimmed) {
return undefined;
}
const mentionTarget = parseTargetMention({
const userTarget = parseMentionPrefixOrAtUserTarget({
raw: trimmed,
mentionPattern: /^<@!?(\d+)>$/,
kind: "user",
});
if (mentionTarget) {
return mentionTarget;
}
const prefixedTarget = parseTargetPrefixes({
raw: trimmed,
prefixes: [
{ prefix: "user:", kind: "user" },
{ prefix: "channel:", kind: "channel" },
{ prefix: "discord:", kind: "user" },
],
atUserPattern: /^\d+$/,
atUserErrorMessage: "Discord DMs require a user id (use user:<id> or a <@id> mention)",
});
if (prefixedTarget) {
return prefixedTarget;
}
if (trimmed.startsWith("@")) {
const candidate = trimmed.slice(1).trim();
const id = ensureTargetId({
candidate,
pattern: /^\d+$/,
errorMessage: "Discord DMs require a user id (use user:<id> or a <@id> mention)",
});
return buildMessagingTarget("user", id, trimmed);
if (userTarget) {
return userTarget;
}
if (/^\d+$/.test(trimmed)) {
if (options.defaultKind) {

View File

@@ -124,6 +124,44 @@ describe("DiscordVoiceManager", () => {
resolveAgentRouteMock.mockClear();
});
const createManager = (
discordConfig: ConstructorParameters<
typeof managerModule.DiscordVoiceManager
>[0]["discordConfig"] = {},
) =>
new managerModule.DiscordVoiceManager({
client: createClient() as never,
cfg: {},
discordConfig,
accountId: "default",
runtime: createRuntime(),
});
const expectConnectedStatus = (
manager: InstanceType<typeof managerModule.DiscordVoiceManager>,
channelId: string,
) => {
expect(manager.status()).toEqual([
{
ok: true,
message: `connected: guild g1 channel ${channelId}`,
guildId: "g1",
channelId,
},
]);
};
const emitDecryptFailure = (manager: InstanceType<typeof managerModule.DiscordVoiceManager>) => {
const entry = (manager as unknown as { sessions: Map<string, unknown> }).sessions.get("g1");
expect(entry).toBeDefined();
(
manager as unknown as { handleReceiveError: (e: unknown, err: unknown) => void }
).handleReceiveError(
entry,
new Error("Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"),
);
};
it("keeps the new session when an old disconnected handler fires", async () => {
const oldConnection = createConnectionMock();
const newConnection = createConnectionMock();
@@ -135,13 +173,7 @@ describe("DiscordVoiceManager", () => {
return undefined;
});
const manager = new managerModule.DiscordVoiceManager({
client: createClient() as never,
cfg: {},
discordConfig: {},
accountId: "default",
runtime: createRuntime(),
});
const manager = createManager();
await manager.join({ guildId: "g1", channelId: "c1" });
await manager.join({ guildId: "g1", channelId: "c2" });
@@ -150,14 +182,7 @@ describe("DiscordVoiceManager", () => {
expect(oldDisconnected).toBeTypeOf("function");
await oldDisconnected?.();
expect(manager.status()).toEqual([
{
ok: true,
message: "connected: guild g1 channel c2",
guildId: "g1",
channelId: "c2",
},
]);
expectConnectedStatus(manager, "c2");
});
it("keeps the new session when an old destroyed handler fires", async () => {
@@ -165,13 +190,7 @@ describe("DiscordVoiceManager", () => {
const newConnection = createConnectionMock();
joinVoiceChannelMock.mockReturnValueOnce(oldConnection).mockReturnValueOnce(newConnection);
const manager = new managerModule.DiscordVoiceManager({
client: createClient() as never,
cfg: {},
discordConfig: {},
accountId: "default",
runtime: createRuntime(),
});
const manager = createManager();
await manager.join({ guildId: "g1", channelId: "c1" });
await manager.join({ guildId: "g1", channelId: "c2" });
@@ -180,26 +199,13 @@ describe("DiscordVoiceManager", () => {
expect(oldDestroyed).toBeTypeOf("function");
oldDestroyed?.();
expect(manager.status()).toEqual([
{
ok: true,
message: "connected: guild g1 channel c2",
guildId: "g1",
channelId: "c2",
},
]);
expectConnectedStatus(manager, "c2");
});
it("removes voice listeners on leave", async () => {
const connection = createConnectionMock();
joinVoiceChannelMock.mockReturnValueOnce(connection);
const manager = new managerModule.DiscordVoiceManager({
client: createClient() as never,
cfg: {},
discordConfig: {},
accountId: "default",
runtime: createRuntime(),
});
const manager = createManager();
await manager.join({ guildId: "g1", channelId: "c1" });
await manager.leave({ guildId: "g1" });
@@ -212,17 +218,11 @@ describe("DiscordVoiceManager", () => {
});
it("passes DAVE options to joinVoiceChannel", async () => {
const manager = new managerModule.DiscordVoiceManager({
client: createClient() as never,
cfg: {},
discordConfig: {
voice: {
daveEncryption: false,
decryptionFailureTolerance: 8,
},
const manager = createManager({
voice: {
daveEncryption: false,
decryptionFailureTolerance: 8,
},
accountId: "default",
runtime: createRuntime(),
});
await manager.join({ guildId: "g1", channelId: "c1" });
@@ -236,36 +236,13 @@ describe("DiscordVoiceManager", () => {
});
it("attempts rejoin after repeated decrypt failures", async () => {
const manager = new managerModule.DiscordVoiceManager({
client: createClient() as never,
cfg: {},
discordConfig: {},
accountId: "default",
runtime: createRuntime(),
});
const manager = createManager();
await manager.join({ guildId: "g1", channelId: "c1" });
const entry = (manager as unknown as { sessions: Map<string, unknown> }).sessions.get("g1");
expect(entry).toBeDefined();
(
manager as unknown as { handleReceiveError: (e: unknown, err: unknown) => void }
).handleReceiveError(
entry,
new Error("Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"),
);
(
manager as unknown as { handleReceiveError: (e: unknown, err: unknown) => void }
).handleReceiveError(
entry,
new Error("Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"),
);
(
manager as unknown as { handleReceiveError: (e: unknown, err: unknown) => void }
).handleReceiveError(
entry,
new Error("Failed to decrypt: DecryptionFailed(UnencryptedWhenPassthroughDisabled)"),
);
emitDecryptFailure(manager);
emitDecryptFailure(manager);
emitDecryptFailure(manager);
await new Promise((resolve) => setTimeout(resolve, 0));
await new Promise((resolve) => setTimeout(resolve, 0));