diff --git a/src/commands/index.ts b/src/commands/index.ts index 21e757e..0aee1e0 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -7,6 +7,7 @@ import { logger } from "@/logger" import { modules } from "@/modules" import type { TelemetryContextFlavor } from "@/modules/telemetry" import { redis } from "@/redis" +import { RedisSet } from "@/redis/set" import { fmt } from "@/utils/format" import { ephemeral } from "@/utils/messages" import type { Context, Role } from "@/utils/types" @@ -19,6 +20,19 @@ import { pin } from "./pin" import { report } from "./report" import { search } from "./search" +const userSet = new RedisSet({ + redis, + prefix: "managed-commands:cached-users", + ttl: 60 * 60 * 24, // 24h, we can afford some staleness here and it helps reduce the number of Redis calls significantly +}) + +const userRolesCache = new RedisFallbackAdapter({ + redis, + prefix: "managed-commands:user-roles", + ttl: 60 * 5, + logger, +}) + const adapter = new RedisFallbackAdapter>({ redis, prefix: "conv", @@ -34,6 +48,14 @@ export const commands = new ManagedCommands { + const key = `${userId}:${chatId}` + if (await userSet.has(key)) { + return true + } + await userSet.add(key) + return false + }, wrongScope: async ({ context, command }) => { await context.deleteMessage().catch(() => {}) logger.info( @@ -104,19 +126,15 @@ export const commands = new ManagedCommands { - // TODO: cache this to avoid hitting the db on every command - const { roles } = await api.tg.permissions.getRoles.query({ userId }) - return roles || [] + const cached = await userRolesCache.read(String(userId)) + if (cached) return cached + + const res = await api.tg.permissions.getRoles.query({ userId }) + const roles = res.roles ?? [] + await userRolesCache.write(String(userId), roles) + return roles }, }) - .createCommand({ - trigger: "ping", - scope: "private", - description: "Replies with pong", - handler: async ({ context }) => { - await context.reply("pong") - }, - }) .createCommand({ trigger: "start", scope: "private", @@ -138,4 +156,12 @@ export const commands = new ManagedCommands { + await context.reply("pong") + }, + }) .withCollection(linkAdminDashboard, report, search, management, moderation, pin, invite) diff --git a/src/lib/managed-commands/command.ts b/src/lib/managed-commands/command.ts index a25c396..52073db 100644 --- a/src/lib/managed-commands/command.ts +++ b/src/lib/managed-commands/command.ts @@ -1,6 +1,6 @@ import type { Conversation } from "@grammyjs/conversations" import type { Context } from "grammy" -import type { Message } from "grammy/types" +import type { BotCommand, Message } from "grammy/types" import type { z } from "zod" import type { MaybeArray } from "@/utils/types" import type { ConversationContext } from "./context" @@ -168,6 +168,14 @@ export type AnyCommand +export type AnyGroupCommand = Command< + CommandArgs, + CommandReplyTo, + "group" | "both", + TRole, + C +> + /** * Type guard to check if a command is allowed in groups. * @param cmd The command to check @@ -221,3 +229,45 @@ export function isAllowedInPrivateOnly< >(cmd: Command): cmd is Command { return cmd.scope === "private" } + +export function isAllowedInPrivate< + A extends CommandArgs, + R extends CommandReplyTo, + TRole extends string = string, + C extends Context = Context, +>(cmd: Command): cmd is Command { + return cmd.scope !== "group" +} + +export function isAllowedEverywhere< + A extends CommandArgs, + R extends CommandReplyTo, + TRole extends string = string, + C extends Context = Context, +>(cmd: Command): cmd is Command { + return cmd.scope === "both" || cmd.scope === undefined +} + +export function toBotCommands(command: AnyCommand): BotCommand[] { + const triggers = Array.isArray(command.trigger) ? command.trigger : [command.trigger] + return triggers.map((trigger) => ({ + command: trigger, + description: command.description ?? "No description", + })) +} + +export function isForThisScope(cmd: AnyCommand, chatType: "private" | "group" | "supergroup" | "channel"): boolean { + if (chatType === "channel") return false + if (cmd.scope === "private") return chatType === "private" + if (cmd.scope === "group") return chatType === "group" || chatType === "supergroup" + return true +} + +export function switchOnScope( + cmd: Command, + handlers: { private: T; group: T; both: T } +) { + if (cmd.scope === "private") return handlers.private + if (cmd.scope === "group") return handlers.group + return handlers.both +} diff --git a/src/lib/managed-commands/index.ts b/src/lib/managed-commands/index.ts index 951c8c3..7cd18f4 100644 --- a/src/lib/managed-commands/index.ts +++ b/src/lib/managed-commands/index.ts @@ -9,16 +9,20 @@ import { hydrate } from "@grammyjs/hydrate" import { hydrateReply, parseMode } from "@grammyjs/parse-mode" import type { CommandContext, Context, Middleware, MiddlewareObj } from "grammy" import { Composer, MemorySessionStorage } from "grammy" -import type { Message } from "grammy/types" +import type { BotCommand, Message } from "grammy/types" import type { Result } from "neverthrow" import { err, ok } from "neverthrow" import z from "zod" +import { asyncFilter, asyncMap } from "@/utils/arrays" import { isFromGroupChat, isFromPrivateChat } from "@/utils/chat" import { fmt } from "@/utils/format" -import { ephemeral } from "@/utils/messages" +import { once } from "@/utils/once" +import type { ContextWith } from "@/utils/types" +import { wait } from "@/utils/wait" import type { CommandsCollection } from "./collection" import type { AnyCommand, + AnyGroupCommand, ArgumentMap, ArgumentOptions, Command, @@ -29,7 +33,7 @@ import type { CommandScopedContext, RepliedTo, } from "./command" -import { isAllowedInGroups, isTypedArgumentOptions } from "./command" +import { isAllowedInGroups, isAllowedInPrivate, isTypedArgumentOptions, switchOnScope, toBotCommands } from "./command" import type { ManagedCommandsFlavor } from "./context" export type Hook = ( @@ -71,7 +75,7 @@ export type ManagedCommandsHooks) => Promise + overrideGroupAdminCheck?: (userId: number, chatId: number, context: OC) => Promise /** * Called when a command is invoked, before any processing is done, can be used to implement custom logic that should * run before checking permissions or requirements, for example logging or analytics @@ -82,6 +86,11 @@ export type ManagedCommandsHooks + /** + * A function to externally cache whether a user has had the commands menu generated for them or not + * @returns true if the user has had the commands menu generated, false otherwise + */ + cachedUserSetCommands?: (userId: number, chatId: number) => Promise } export interface IManagedCommandsOptions { @@ -91,7 +100,7 @@ export interface IManagedCommandsOptions + adapter?: ConversationStorage /** * A function to get externally defined roles for a specific user. @@ -99,7 +108,7 @@ export interface IManagedCommandsOptions { + * getUserRoles: async (userId) => { * const roles = await db.getUserRoles(userId) // Array<"admin" | "user">[] * return roles * }, @@ -114,7 +123,7 @@ export interface IManagedCommandsOptions) => Promise + getUserRoles?: (userId: number) => Promise /** * Additional plugins to apply to the conversation inner composer. @@ -124,7 +133,7 @@ export interface IManagedCommandsOptions + hooks?: ManagedCommandsHooks } export type ManagedCommandsOptions = string extends TRole @@ -162,7 +171,7 @@ export class ManagedCommands< { private composer = new Composer() private commands: Record[]> = {} - private getUserRoles: (userId: number, context: CommandContext) => Promise + private getUserRoles: (userId: number) => Promise private hooks: ManagedCommandsHooks private adapter: ConversationStorage private registeredTriggers = new Set() @@ -263,9 +272,11 @@ export class ManagedCommands< */ private static formatCommandUsage(cmd: AnyCommand): string { const args = cmd.args ?? [] - const scope = - cmd.scope === "private" ? "Private Chat" : cmd.scope === "group" ? "Groups" : "Groups and Private Chat" - + const scope = switchOnScope(cmd, { + private: "šŸ‘¤ Private chats only", + group: "šŸ‘„ Groups only", + both: "šŸŒ Both private and group chats", + }) return fmt(({ n, b, i }) => [ typeof cmd.trigger === "string" ? `/${cmd.trigger}` : cmd.trigger.map((t) => `/${t}`).join(" | "), ...args.map(({ key, optional }) => (optional ? n`[${i`${key}`}]` : n`<${i`${key}`}>`)), @@ -281,10 +292,14 @@ export class ManagedCommands< private static formatCommandShort(cmd: AnyCommand): string { const args = cmd.args ?? [] + const trigger: string = + typeof cmd.trigger === "string" ? `/${cmd.trigger}` : cmd.trigger.map((t) => `/${t}`).join(" | ") + const scope = switchOnScope(cmd, { private: "šŸ‘¤", group: "šŸ‘„", both: "šŸŒ" }) + const admin = isAllowedInGroups(cmd) && cmd.permissions?.allowGroupAdmins ? "šŸ›”ļø" : "" return fmt(({ i, n }) => [ - typeof cmd.trigger === "string" ? `/${cmd.trigger}` : cmd.trigger.map((t) => `/${t}`).join(" | "), + trigger, ...args.map(({ key, optional }) => (optional ? i` [${key}]` : i` <${key}>`)), - n`\n\t${cmd.description ?? "No description"}`, + n`\n\t${scope}${admin}${cmd.description ?? "No description"}`, ]) } @@ -364,13 +379,41 @@ export class ManagedCommands< }) ) - this.composer.command("help", async (ctx) => { - if (ctx.chat.type !== "private") - return void ephemeral( - ctx.reply(fmt(({ n, code }) => n`You can only send ${code`/help`} in private chat with the bot.`)), - 10_000 - ) + const setFreeCommands = once(async (ctx: OC) => { + const freeCommands = this.getCommands().filter((cmd) => this.isCommandAllowedForRoles(cmd, [])) + const privateCommands: BotCommand[] = freeCommands + .filter((cmd) => isAllowedInPrivate(cmd)) + .flatMap((cmd) => toBotCommands(cmd)) + .concat([{ command: "help", description: "Show available commands" }]) + await ctx.api.setMyCommands(privateCommands, { scope: { type: "all_private_chats" } }).catch(() => {}) + const groupCommands: BotCommand[] = freeCommands + .filter((cmd) => isAllowedInGroups(cmd) && this.isCommandAllowedInGroup(cmd, -100)) // only include commands that are allowed in all groups + .flatMap((cmd) => toBotCommands(cmd)) + .concat([{ command: "help", description: "Show available commands" }]) + await ctx.api.setMyCommands(groupCommands, { scope: { type: "all_group_chats" } }).catch(() => {}) + }) + + this.composer.use(async (ctx, next) => { + await setFreeCommands(ctx) + return next() + }) + + this.composer.on("message").use(async (ctx, next) => { + if (!ctx.from) return next() + const shouldSkip = (await this.hooks.cachedUserSetCommands?.(ctx.from.id, ctx.chat.id)) ?? false + if (shouldSkip) return next() + const allowedCommands = await this.getAllowedCommandsFor(ctx) + await ctx.api + .setMyCommands(allowedCommands.flatMap(toBotCommands), { + scope: { type: "chat_member", chat_id: ctx.chat.id, user_id: ctx.from.id }, + }) + .catch(() => {}) + return next() + }) + this.composer.command("help", async (ctx) => { + if (!ctx.from) return + const userId = ctx.from.id const text = ctx.message?.text ?? "" const [_, cmdArg] = text.replaceAll("/", "").split(" ") @@ -383,14 +426,30 @@ export class ManagedCommands< return ctx.reply(ManagedCommands.formatCommandUsage(cmd)) } + const getUserRoles = once(async () => await this.getUserRoles(userId)) + const isFromGroupAdmin = once(async () => { + if (ctx.chat.type === "private") return true + return await this.isFromGroupAdmin(ctx) + }) + + const rawCollections = await asyncMap(Object.entries(this.commands), async ([collection, cmds]) => ({ + collection, + commands: await asyncFilter(cmds, async (cmd) => + this.checkPermissionsCached(cmd, ctx, getUserRoles, isFromGroupAdmin) + ), + })) + const collections = rawCollections.filter((c) => c.commands.length > 0) + const reply = fmt( - ({ u, b, skip, n, code }) => [ + ({ u, b, skip, n, code, i }) => [ b`Available commands:`, - ...Object.entries(this.commands).flatMap(([collection, cmds]) => [ + ...collections.flatMap(({ collection, commands }) => [ collection === "default" ? "" : u`${b`\n${collection}:`}`, - ...cmds.flatMap((cmd) => [skip`${ManagedCommands.formatCommandShort(cmd)}`]), + ...commands.map((cmd) => skip`${ManagedCommands.formatCommandShort(cmd)}`), ]), - n`\n\nType ${code`\/help `} for more details on a specific command.`, + i`\nšŸ‘¤: Private only, šŸ‘„: Group only, šŸŒ: Everywhere`, + i`Commands marked with šŸ›”ļø are restricted to administrators.`, + n`Type ${code`\/help `} for more details on a specific command.`, ], { sep: "\n" } ) @@ -407,36 +466,84 @@ export class ManagedCommands< return cmds } - private async checkPermissions(command: AnyCommand, ctx: CommandContext): Promise { - if (!command.permissions) return true - if (!ctx.from) return false + /** + * Checks whether a command is allowed in a specific group based on its permissions + */ + private isCommandAllowedInGroup(command: AnyGroupCommand, chatId: number): boolean { + const { allowedGroupsId, excludedGroupsId } = command.permissions ?? {} + if (allowedGroupsId && !allowedGroupsId.includes(chatId)) return false + if (excludedGroupsId?.includes(chatId)) return false + return true + } - const { allowedRoles, excludedRoles } = command.permissions + /** + * Checks whether a command is allowed for a specific set of roles based on its permissions + */ + private isCommandAllowedForRoles(command: AnyCommand, roles: TRole[]): boolean { + const { allowedRoles, excludedRoles } = command.permissions ?? {} + if (allowedRoles?.every((r) => !roles.includes(r))) return false + if (excludedRoles?.some((r) => roles.includes(r))) return false + return true + } - if (isAllowedInGroups(command) && (ctx.chat.type === "group" || ctx.chat.type === "supergroup")) { - const { allowGroupAdmins, allowedGroupsId, excludedGroupsId } = command.permissions + private async isFromGroupAdmin(ctx: OC): Promise { + if (!ctx.from || !ctx.chatId) return false + if (this.hooks.overrideGroupAdminCheck) { + const isAdmin = await this.hooks.overrideGroupAdminCheck(ctx.from.id, ctx.chatId, ctx) + if (isAdmin) return true + } else { + const { status: groupRole } = await ctx.getChatMember(ctx.from.id) + if (groupRole === "administrator" || groupRole === "creator") return true + } + return false + } - if (allowedGroupsId && !allowedGroupsId.includes(ctx.chatId)) return false - if (excludedGroupsId?.includes(ctx.chatId)) return false + private async getAllowedCommandsFor(ctx: ContextWith): Promise[]> { + const getUserRoles = once(() => this.getUserRoles(ctx.from.id)) + const isFromGroupAdmin = once(() => this.isFromGroupAdmin(ctx)) + + return await Promise.all( + this.getCommands() + .filter(isFromPrivateChat(ctx) ? (cmd) => isAllowedInPrivate(cmd) : (cmd) => isAllowedInGroups(cmd)) + .map((cmd) => + this.checkPermissionsCached(cmd, ctx, getUserRoles, isFromGroupAdmin).then((allowed) => + allowed ? cmd : null + ) + ) + ).then((cmds) => cmds.filter((c) => c !== null)) + } - if (allowGroupAdmins) { - if (this.hooks.overrideGroupAdminCheck) { - const isAdmin = await this.hooks.overrideGroupAdminCheck(ctx.from.id, ctx.chatId, ctx) - if (isAdmin) return true - } else { - const { status: groupRole } = await ctx.getChatMember(ctx.from.id) - if (groupRole === "administrator" || groupRole === "creator") return true - } + private async checkPermissionsCached( + command: AnyCommand, + ctx: ContextWith, + getUserRoles: () => Promise, + isFromGroupAdmin: () => Promise + ): Promise { + if (!command.permissions) return true + + if (isAllowedInGroups(command)) { + const allowed = this.isCommandAllowedInGroup(command, ctx.chat.id) + if (!allowed) return false + + if (command.permissions.allowGroupAdmins) { + const isAdmin = await isFromGroupAdmin() + if (isAdmin) return true } } - const roles = await this.getUserRoles(ctx.from.id, ctx) - - // blacklist is stronger than whitelist - if (allowedRoles?.every((r) => !roles.includes(r))) return false - if (excludedRoles?.some((r) => roles.includes(r))) return false + const roles = await getUserRoles() + return this.isCommandAllowedForRoles(command, roles) + } - return true + private async checkPermissions(command: AnyCommand, ctx: CommandContext): Promise { + if (!ctx.from) return false + const userId = ctx.from.id + return this.checkPermissionsCached( + command, + ctx, + () => this.getUserRoles(userId), + () => this.isFromGroupAdmin(ctx) + ) } /** @@ -538,7 +645,10 @@ export class ManagedCommands< code`/help ${Array.isArray(cmd.trigger) ? cmd.trigger[0] : cmd.trigger}`, ]) ) - if (!isPrivate) void ephemeral(msg, 10_000) // delete the error message after some time in groups, no need to keep it + if (!isPrivate) + void wait(10_000) + .then(() => msg.delete()) + .catch(() => {}) return } diff --git a/src/redis/set.ts b/src/redis/set.ts new file mode 100644 index 0000000..ada3c75 --- /dev/null +++ b/src/redis/set.ts @@ -0,0 +1,117 @@ +import { EventEmitter } from "node:events" +import { + createClient, + type RedisClientOptions, + type RedisClientType, + type RedisFunctions, + type RedisModules, + type RedisScripts, +} from "redis" + +export interface RedisSetOptions { + /** Redis client instance, or options to create one */ + redis: RedisClientType | RedisClientOptions + /** Time to live for each entry in seconds, uses redis' EXPIRE command */ + ttl?: number + /** + * Prefix for each key stored in redis, to avoid collisions, if not provided a + * default one will be used to ensure uniqueness across multiple instances + */ + prefix?: string +} + +export class RedisSet< + M extends RedisModules = RedisModules, + F extends RedisFunctions = RedisFunctions, + S extends RedisScripts = RedisScripts, +> { + private static instanceCount = 0 + private prefix: string + // In-memory cache used when Redis is not available + private memoryCache: Set = new Set() + // temporary store for keys that need to be deleted once redis is back (used when delete does not find the key in memoryCache) + private deletions: Set = new Set() + private redisClient: RedisClientType + + constructor(private options: RedisSetOptions) { + const prefix = options.prefix ?? `redis-set-${RedisSet.instanceCount++}` + if (prefix.endsWith(":")) { + prefix.slice(0, -1) + } + this.prefix = prefix + if (options.redis instanceof EventEmitter) { + // RedisClient extends event emitter :) + this.redisClient = options.redis + } else { + this.redisClient = createClient(options.redis) + void this.redisClient.connect() + } + + this.redisClient.on("ready", () => { + void this.flushMemoryCache() + }) + } + + /** + * Flush the in-memory cache to Redis. Called automatically when the Redis + * connection is re-established. + */ + private async flushMemoryCache() { + // write all memoryCache entries to redis + await Promise.all(this.memoryCache.values().map((value) => this._add(value))) + this.memoryCache.clear() + // delete all keys that were marked for deletion while redis was down + await Promise.all(this.deletions.values().map((k) => this._delete(k))) + this.deletions.clear() + } + + private ready(): boolean { + return this.redisClient.isOpen && this.redisClient.isReady + } + + /** + * Writes a value to Redis. + * + * Sets an expiry if ttl is set in options. + * @param value The value to insert in the set. + */ + private async _add(value: string) { + await this.redisClient.sAdd(this.prefix, value) + if (this.options.ttl) { + await this.redisClient.expire(this.prefix, this.options.ttl) + } + } + + /** + * Deletes a key from Redis. + * @param key The key to delete. + */ + private async _delete(value: string) { + await this.redisClient.sRem(this.prefix, value) + } + + async add(value: string): Promise { + if (this.ready()) { + await this._add(value) + } else { + this.memoryCache.add(value) + } + } + + async delete(value: string): Promise { + if (this.ready()) { + await this._delete(value) + } else { + // Try to delete from memory cache, if not found add to deletions set + if (!this.memoryCache.delete(value)) this.deletions.add(value) + } + } + + async has(value: string): Promise { + if (this.ready()) { + return await this.redisClient.sIsMember(this.prefix, value) + } else { + return this.memoryCache.has(value) + } + } +} diff --git a/src/utils/arrays.ts b/src/utils/arrays.ts new file mode 100644 index 0000000..4cb15c8 --- /dev/null +++ b/src/utils/arrays.ts @@ -0,0 +1,9 @@ +export function asyncFilter(arr: T[], predicate: (item: T) => Promise): Promise { + return Promise.all(arr.map(async (item) => ({ item, keep: await predicate(item) }))).then((results) => + results.filter((result) => result.keep).map((result) => result.item) + ) +} + +export function asyncMap(arr: T[], mapper: (item: T) => Promise): Promise { + return Promise.all(arr.map(mapper)) +} diff --git a/src/utils/once.ts b/src/utils/once.ts index a6ae914..9805a3c 100644 --- a/src/utils/once.ts +++ b/src/utils/once.ts @@ -15,14 +15,10 @@ import type { MaybePromise } from "./types" * @returns A wrapped version of `fn` that only runs on the first call */ export function once(fn: (...args: A) => MaybePromise) { - let called = false - let result: R + let result: Promise> | undefined - return async (...args: A) => { - if (!called) { - called = true - result = await fn(...args) - } + return (...args: A) => { + result ??= Promise.resolve(fn(...args)) return result } } diff --git a/src/utils/types.ts b/src/utils/types.ts index b6ccd85..0c42622 100644 --- a/src/utils/types.ts +++ b/src/utils/types.ts @@ -5,14 +5,8 @@ import type { ApiInput, ApiOutput } from "@/backend" import type { ManagedCommandsFlavor } from "@/lib/managed-commands" import type { TelemetryContextFlavor } from "@/modules/telemetry" -export type OptionalPropertyOf = Exclude< - { - [K in keyof T]: T[K] extends undefined ? never : K - }[keyof T], - undefined -> -export type ContextWith

> = Exclude & { - [K in P]: NonNullable +export type ContextWith = C & { + [K in P]: NonNullable } export type MaybePromise = T | Promise diff --git a/tests/common/dummy-bot.ts b/tests/common/dummy-bot.ts new file mode 100644 index 0000000..e16c89f --- /dev/null +++ b/tests/common/dummy-bot.ts @@ -0,0 +1,163 @@ +import { type ApiCallFn, Bot, type Context, type RawApi } from "grammy" +import type { Update } from "grammy/types" + +type ApiFunction = ApiCallFn +export type ResultType = Awaited> +type Params = Parameters +type PayloadType = Params[1] +export type OutgoingRequest = { + method: string + payload: PayloadType +} + +/** Returns the current timestamp in seconds since the Unix epoch. */ +function now() { + return Math.floor(Date.now() / 1000) +} + +/** + * Creates a dummy bot instance for testing purposes. + * @returns An object containing the dummy bot and an array to capture outgoing API requests + */ +export async function createDummyBot() { + const bot = new Bot("token") + const outgoingRequests: OutgoingRequest[] = [] + + bot.api.config.use(async (_, method, payload) => { + outgoingRequests.push({ method, payload }) + return { ok: true, result: true as ResultType } + }) + + bot.botInfo = { + id: 42, + first_name: "Dummy Bot", + is_bot: true, + username: "dummy_bot", + can_join_groups: true, + can_read_all_group_messages: false, + supports_inline_queries: false, + can_connect_to_business: false, + has_main_web_app: false, + } + + await bot.init() + + return { + bot, + outgoingRequests, + } +} + +/** Generates a dummy command call `Update` for testing purposes. */ +export function generateCommandCall(trigger: string, id: number = 0): Update { + return { + update_id: 0, + message: { + text: `/${trigger}`, + message_id: 0, + date: now(), + entities: [ + { + type: "bot_command", + offset: 0, + length: trigger.length + 1, + }, + ], + chat: { + id, + first_name: "Test", + last_name: "Lastest", + username: "testuser", + type: "private", + }, + from: { + id, + first_name: "Test", + last_name: "Lastest", + username: "testuser", + is_bot: false, + }, + }, + } +} + +/** Generates a dummy command call `Update` for testing purposes, from a group chat */ +export function generateGroupCommandCall(trigger: string, id: number = 0): Update { + return { + update_id: 0, + message: { + text: `/${trigger}`, + message_id: 0, + date: now(), + entities: [ + { + type: "bot_command", + offset: 0, + length: trigger.length + 1, + }, + ], + chat: { + id, + title: "Test Group", + type: "group", + }, + from: { + id, + first_name: "Test", + last_name: "Lastest", + username: "testuser", + is_bot: false, + }, + }, + } +} + +/** Generates a dummy text message `Update` for testing purposes. */ +export function generateMessage(text: string, id: number = 0): Update { + return { + update_id: 0, + message: { + text, + message_id: 0, + date: now(), + chat: { + id, + first_name: "Test", + last_name: "Lastest", + username: "testuser", + type: "private", + }, + from: { + id, + first_name: "Test", + last_name: "Lastest", + username: "testuser", + is_bot: false, + }, + }, + } +} + +/** Generates a dummy text message `Update` for testing purposes, from a group chat */ +export function generateGroupMessage(text: string, id: number = 0): Update { + return { + update_id: 0, + message: { + text, + message_id: 0, + date: now(), + chat: { + id, + title: "Test Group", + type: "group", + }, + from: { + id, + first_name: "Test", + last_name: "Lastest", + username: "testuser", + is_bot: false, + }, + }, + } +} diff --git a/tests/managed-commands/permissions.test.ts b/tests/managed-commands/permissions.test.ts new file mode 100644 index 0000000..c9df46d --- /dev/null +++ b/tests/managed-commands/permissions.test.ts @@ -0,0 +1,312 @@ +import { hydrate } from "@grammyjs/hydrate" +import { hydrateReply, parseMode } from "@grammyjs/parse-mode" +import { MemorySessionStorage } from "grammy" +import { beforeEach, describe, expect, it } from "vitest" +import { ManagedCommands, type ManagedCommandsFlavor } from "@/lib/managed-commands" +import { + createDummyBot, + generateCommandCall, + generateGroupCommandCall, + type OutgoingRequest, + type ResultType, +} from "../common/dummy-bot" + +let wasMissingPermissions = false +const { bot, outgoingRequests } = await createDummyBot() + +type Role = "admin" | "mod" | "banned" + +const commands = new ManagedCommands({ + adapter: new MemorySessionStorage(), + getUserRoles: async (userId) => { + if (userId === 1) return ["admin"] + if (userId === 2) return ["mod"] + if (userId === 3) return ["banned"] + if (userId === 4) return ["admin", "banned"] + return [] + }, + plugins: [ + async (ctx, next) => { + ctx.api.config.use(async (_, method, payload) => { + outgoingRequests.push({ method, payload }) + return { ok: true, result: true as ResultType } + }) + await next() + }, + ], + hooks: { + overrideGroupAdminCheck: async (userId) => { + return userId === 1 || userId === 99 + }, + missingPermissions: async () => { + wasMissingPermissions = true + }, + }, +}) + .createCommand({ + trigger: "public", + handler: async ({ context }) => { + await context.reply("Public command executed") + }, + }) + .createCommand({ + trigger: "private", + scope: "private", + handler: async ({ context }) => { + await context.reply("Private command executed") + }, + }) + .createCommand({ + trigger: "group", + scope: "group", + permissions: { + allowGroupAdmins: false, + }, + handler: async ({ context }) => { + await context.reply("Group command executed") + }, + }) + .createCommand({ + trigger: "role_admin", + permissions: { + allowedRoles: ["admin"], + }, + handler: async ({ context }) => { + await context.reply("Role admin command executed") + }, + }) + .createCommand({ + trigger: "role_mod", + permissions: { + allowedRoles: ["mod"], + }, + handler: async ({ context }) => { + await context.reply("Role mod command executed") + }, + }) + .createCommand({ + trigger: "role_admin_or_mod", + permissions: { + allowedRoles: ["admin", "mod"], + }, + handler: async ({ context }) => { + await context.reply("Role admin or mod command executed") + }, + }) + .createCommand({ + trigger: "excluded_banned", + permissions: { + excludedRoles: ["banned"], + }, + handler: async ({ context }) => { + await context.reply("Excluded banned command executed") + }, + }) + .createCommand({ + trigger: "group_admin", + scope: "group", + permissions: { + allowedRoles: [], + allowGroupAdmins: true, + }, + handler: async ({ context }) => { + await context.reply("Group admin command executed") + }, + }) + .createCommand({ + trigger: "group_allowed_only", + scope: "group", + permissions: { + allowGroupAdmins: false, + allowedGroupsId: [50], + }, + handler: async ({ context }) => { + await context.reply("Group allowed-only command executed") + }, + }) + .createCommand({ + trigger: "group_excluded", + scope: "group", + permissions: { + allowGroupAdmins: false, + excludedGroupsId: [60], + }, + handler: async ({ context }) => { + await context.reply("Group excluded command executed") + }, + }) + .createCommand({ + trigger: "group_allowed_and_excluded", + scope: "group", + permissions: { + allowGroupAdmins: false, + allowedGroupsId: [70], + excludedGroupsId: [70], + }, + handler: async ({ context }) => { + await context.reply("Group allowed-and-excluded command executed") + }, + }) + +bot.use(hydrate()) +bot.use(hydrateReply) +bot.api.config.use(parseMode("MarkdownV2")) +bot.use(commands) + +beforeEach(() => { + outgoingRequests.length = 0 + wasMissingPermissions = false +}) + +function payloadText(request?: OutgoingRequest): string | undefined { + if (!!request && "text" in request.payload) { + return request.payload.text + } + return undefined +} + +function sendMessages(): OutgoingRequest[] { + return outgoingRequests.filter((request) => request.method === "sendMessage") +} + +function lastSentText(): string | undefined { + const request = sendMessages().at(-1) + return payloadText(request) +} + +function expectNoMessageWithText(text: string): void { + expect(sendMessages().some((request) => payloadText(request) === text)).toBe(false) +} + +function normalizeMarkdownEscapes(text?: string): string { + return (text ?? "").replaceAll("\\_", "_") +} + +function expectMissingPermissions(): void { + expect(wasMissingPermissions).toBe(true) + wasMissingPermissions = false +} + +describe("ManagedCommands - Permissions", () => { + it("executes command without permissions", async () => { + await bot.handleUpdate(generateCommandCall("public")) + expect(lastSentText()).toBe("Public command executed") + }) + + it("allows command when user has a required role", async () => { + await bot.handleUpdate(generateCommandCall("role_admin", 1)) + expect(lastSentText()).toBe("Role admin command executed") + }) + + it("denies command when user misses required role", async () => { + await bot.handleUpdate(generateCommandCall("role_admin", 2)) + expectNoMessageWithText("Role admin command executed") + expectMissingPermissions() + }) + + it("allows command when at least one allowed role matches", async () => { + await bot.handleUpdate(generateCommandCall("role_admin_or_mod", 2)) + expect(lastSentText()).toBe("Role admin or mod command executed") + }) + + it("denies command when all allowed roles are missing", async () => { + await bot.handleUpdate(generateCommandCall("role_admin_or_mod")) + expectNoMessageWithText("Role admin or mod command executed") + expectMissingPermissions() + }) + + it("denies command when user has an excluded role", async () => { + await bot.handleUpdate(generateCommandCall("excluded_banned", 3)) + expectNoMessageWithText("Excluded banned command executed") + expectMissingPermissions() + }) + + it("denies command when excluded role is present with an allowed one", async () => { + await bot.handleUpdate(generateCommandCall("excluded_banned", 4)) + expectNoMessageWithText("Excluded banned command executed") + expectMissingPermissions() + }) + + it("allows group-admin command for group admins without external roles", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_admin", 99)) + expect(lastSentText()).toBe("Group admin command executed") + }) + + it("denies group-admin command for non-admin users without required roles", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_admin")) + expectNoMessageWithText("Group admin command executed") + expectMissingPermissions() + }) + + it("allows command only in explicitly allowed groups", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_allowed_only", 50)) + expect(lastSentText()).toBe("Group allowed-only command executed") + }) + + it("denies command outside explicitly allowed groups", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_allowed_only", 51)) + expectNoMessageWithText("Group allowed-only command executed") + expectMissingPermissions() + }) + + it("denies command in excluded groups", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_excluded", 60)) + expectNoMessageWithText("Group excluded command executed") + expectMissingPermissions() + }) + + it("allows command in non-excluded groups", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_excluded", 61)) + expect(lastSentText()).toBe("Group excluded command executed") + }) + + it("gives exclusion precedence when group is both allowed and excluded", async () => { + await bot.handleUpdate(generateGroupCommandCall("group_allowed_and_excluded", 70)) + expectNoMessageWithText("Group allowed-and-excluded command executed") + expectMissingPermissions() + }) + + it("shows only commands allowed by role in help for regular users", async () => { + await bot.handleUpdate(generateCommandCall("help", 10)) + const text = normalizeMarkdownEscapes(lastSentText()) + expect(text).toContain("Available commands") + expect(text).toContain("/public") + expect(text).toContain("/private") + expect(text).toContain("/group") + expect(text).toContain("/excluded_banned") + expect(text).toContain("/group_admin") + expect(text).not.toContain("/role_admin") + expect(text).not.toContain("/role_mod") + expect(text).not.toContain("/role_admin_or_mod") + }) + + it("shows role-gated commands in help for admins", async () => { + await bot.handleUpdate(generateCommandCall("help", 1)) + const text = normalizeMarkdownEscapes(lastSentText()) + expect(text).toContain("/role_admin") + expect(text).toContain("/role_admin_or_mod") + expect(text).not.toContain("/role_mod") + expect(text).toContain("/excluded_banned") + }) + + it("hides excluded commands in help for excluded users", async () => { + await bot.handleUpdate(generateCommandCall("help", 3)) + const text = normalizeMarkdownEscapes(lastSentText()) + expect(text).toContain("Available commands") + expect(text).not.toContain("/excluded_banned") + }) + + it("hides group-restricted command in help when current group is not allowed", async () => { + await bot.handleUpdate(generateGroupCommandCall("help", 51)) + const text = normalizeMarkdownEscapes(lastSentText()) + expect(text).toContain("Available commands") + expect(text).not.toContain("/group_allowed_only") + }) + + it("shows group-restricted command in help when current group is allowed", async () => { + await bot.handleUpdate(generateGroupCommandCall("help", 50)) + const text = normalizeMarkdownEscapes(lastSentText()) + expect(text).toContain("Available commands") + expect(text).toContain("/group_allowed_only") + }) +}) diff --git a/tests/awaiter.test.ts b/tests/utils/awaiter.test.ts similarity index 100% rename from tests/awaiter.test.ts rename to tests/utils/awaiter.test.ts diff --git a/tests/format.test.ts b/tests/utils/format.test.ts similarity index 100% rename from tests/format.test.ts rename to tests/utils/format.test.ts diff --git a/tests/utils/once.test.ts b/tests/utils/once.test.ts new file mode 100644 index 0000000..46c5439 --- /dev/null +++ b/tests/utils/once.test.ts @@ -0,0 +1,66 @@ +import { describe, expect, it, vi } from "vitest" +import { once } from "@/utils/once" + +describe("once", () => { + it("executes the wrapped function only once for sync results", async () => { + const fn = vi.fn((value: number) => value * 2) + const wrapped = once(fn) + + const first = await wrapped(2) + const second = await wrapped(7) + const third = await wrapped(11) + + expect(first).toBe(4) + expect(second).toBe(4) + expect(third).toBe(4) + expect(fn).toHaveBeenCalledTimes(1) + expect(fn).toHaveBeenCalledWith(2) + }) + + it("returns the same promise for repeated async calls", async () => { + const fn = vi.fn(async (name: string) => ({ name, createdAt: Date.now() })) + const wrapped = once(fn) + + const promise1 = wrapped("first") + const promise2 = wrapped("second") + const promise3 = wrapped("third") + + expect(promise1).toBe(promise2) + expect(promise2).toBe(promise3) + + const result1 = await promise1 + const result2 = await promise2 + const result3 = await promise3 + + expect(result1).toEqual({ name: "first", createdAt: expect.any(Number) }) + expect(result2).toEqual(result1) + expect(result3).toEqual(result1) + expect(fn).toHaveBeenCalledTimes(1) + expect(fn).toHaveBeenCalledWith("first") + }) + + it("preserves the original rejection for later calls", async () => { + const error = new Error("boom") + const fn = vi.fn(async () => { + throw error + }) + const wrapped = once(fn) + + await expect(wrapped()).rejects.toThrow(error) + // @ts-expect-error: This is testing that the same error is thrown on subsequent calls, even if the arguments are different + await expect(wrapped("ignored")).rejects.toThrow(error) + + expect(fn).toHaveBeenCalledTimes(1) + }) + + it("keeps the first resolved value even if later calls pass different arguments", async () => { + const fn = vi.fn((value: string, suffix: string) => `${value}-${suffix}`) + const wrapped = once(fn) + + await expect(wrapped("alpha", "one")).resolves.toBe("alpha-one") + await expect(wrapped("beta", "two")).resolves.toBe("alpha-one") + + expect(fn).toHaveBeenCalledTimes(1) + expect(fn).toHaveBeenCalledWith("alpha", "one") + }) +}) diff --git a/tests/throttle.test.ts b/tests/utils/throttle.test.ts similarity index 100% rename from tests/throttle.test.ts rename to tests/utils/throttle.test.ts